Skip to content

Commit

Permalink
8310268: RISC-V: misaligned memory access in String.Compare intrinsic
Browse files Browse the repository at this point in the history
Co-authored-by: Feilong Jiang <fjiang@openjdk.org>
Reviewed-by: fyang
  • Loading branch information
Vladimir Kempik and feilongjiang committed Jul 28, 2023
1 parent 402cb6a commit d6245b6
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 120 deletions.
34 changes: 13 additions & 21 deletions src/hotspot/cpu/riscv/c2_MacroAssembler_riscv.cpp
Expand Up @@ -868,9 +868,10 @@ void C2_MacroAssembler::string_compare(Register str1, Register str2,
// load first parts of strings and finish initialization while loading
{
if (str1_isL == str2_isL) { // LL or UU
// check if str1 and str2 are the same pointer
beq(str1, str2, DONE);
// load 8 bytes once to compare
ld(tmp1, Address(str1));
beq(str1, str2, DONE);
ld(tmp2, Address(str2));
mv(t0, STUB_THRESHOLD);
bge(cnt2, t0, STUB);
Expand Down Expand Up @@ -913,9 +914,8 @@ void C2_MacroAssembler::string_compare(Register str1, Register str2,
addi(cnt1, cnt1, 8);
}
addi(cnt2, cnt2, isUL ? 4 : 8);
bne(tmp1, tmp2, DIFFERENCE);
bgez(cnt2, TAIL);
xorr(tmp3, tmp1, tmp2);
bnez(tmp3, DIFFERENCE);

// main loop
bind(NEXT_WORD);
Expand Down Expand Up @@ -944,38 +944,30 @@ void C2_MacroAssembler::string_compare(Register str1, Register str2,
addi(cnt1, cnt1, 8);
addi(cnt2, cnt2, 4);
}
bgez(cnt2, TAIL);

xorr(tmp3, tmp1, tmp2);
beqz(tmp3, NEXT_WORD);
j(DIFFERENCE);
bne(tmp1, tmp2, DIFFERENCE);
bltz(cnt2, NEXT_WORD);
bind(TAIL);
xorr(tmp3, tmp1, tmp2);
bnez(tmp3, DIFFERENCE);
// Last longword. In the case where length == 4 we compare the
// same longword twice, but that's still faster than another
// conditional branch.
if (str1_isL == str2_isL) { // LL or UU
ld(tmp1, Address(str1));
ld(tmp2, Address(str2));
load_long_misaligned(tmp1, Address(str1), tmp3, isLL ? 1 : 2);
load_long_misaligned(tmp2, Address(str2), tmp3, isLL ? 1 : 2);
} else if (isLU) { // LU case
lwu(tmp1, Address(str1));
ld(tmp2, Address(str2));
load_int_misaligned(tmp1, Address(str1), tmp3, false);
load_long_misaligned(tmp2, Address(str2), tmp3, 2);
inflate_lo32(tmp3, tmp1);
mv(tmp1, tmp3);
} else { // UL case
lwu(tmp2, Address(str2));
ld(tmp1, Address(str1));
load_int_misaligned(tmp2, Address(str2), tmp3, false);
load_long_misaligned(tmp1, Address(str1), tmp3, 2);
inflate_lo32(tmp3, tmp2);
mv(tmp2, tmp3);
}
bind(TAIL_CHECK);
xorr(tmp3, tmp1, tmp2);
beqz(tmp3, DONE);
beq(tmp1, tmp2, DONE);

// Find the first different characters in the longwords and
// compute their difference.
bind(DIFFERENCE);
xorr(tmp3, tmp1, tmp2);
ctzc_bit(result, tmp3, isLL); // count zero from lsb to msb
srl(tmp1, tmp1, result);
srl(tmp2, tmp2, result);
Expand Down
30 changes: 10 additions & 20 deletions src/hotspot/cpu/riscv/macroAssembler_riscv.cpp
Expand Up @@ -3967,18 +3967,17 @@ void MacroAssembler::ctzc_bit(Register Rd, Register Rs, bool isLL, Register tmp1
void MacroAssembler::inflate_lo32(Register Rd, Register Rs, Register tmp1, Register tmp2) {
assert_different_registers(Rd, Rs, tmp1, tmp2);

mv(tmp1, 0xFF);
mv(Rd, zr);
for (int i = 0; i <= 3; i++) {
mv(tmp1, 0xFF000000); // first byte mask at lower word
andr(Rd, Rs, tmp1);
for (int i = 0; i < 2; i++) {
slli(Rd, Rd, wordSize);
srli(tmp1, tmp1, wordSize);
andr(tmp2, Rs, tmp1);
if (i) {
slli(tmp2, tmp2, i * 8);
}
orr(Rd, Rd, tmp2);
if (i != 3) {
slli(tmp1, tmp1, 8);
}
}
slli(Rd, Rd, wordSize);
andi(tmp2, Rs, 0xFF); // last byte mask at lower word
orr(Rd, Rd, tmp2);
}

// This instruction reads adjacent 4 bytes from the upper half of source register,
Expand All @@ -3987,17 +3986,8 @@ void MacroAssembler::inflate_lo32(Register Rd, Register Rs, Register tmp1, Regis
// Rd: 00A700A600A500A4
void MacroAssembler::inflate_hi32(Register Rd, Register Rs, Register tmp1, Register tmp2) {
assert_different_registers(Rd, Rs, tmp1, tmp2);

mv(tmp1, 0xFF00000000);
mv(Rd, zr);
for (int i = 0; i <= 3; i++) {
andr(tmp2, Rs, tmp1);
orr(Rd, Rd, tmp2);
srli(Rd, Rd, 8);
if (i != 3) {
slli(tmp1, tmp1, 8);
}
}
srli(Rs, Rs, 32); // only upper 32 bits are needed
inflate_lo32(Rd, Rs, tmp1, tmp2);
}

// The size of the blocks erased by the zero_blocks stub. We must
Expand Down
145 changes: 72 additions & 73 deletions src/hotspot/cpu/riscv/stubGenerator_riscv.cpp
Expand Up @@ -2275,24 +2275,21 @@ class StubGenerator: public StubCodeGenerator {
}

// code for comparing 8 characters of strings with Latin1 and Utf16 encoding
void compare_string_8_x_LU(Register tmpL, Register tmpU, Label &DIFF1,
Label &DIFF2) {
const Register strU = x12, curU = x7, strL = x29, tmp = x30;
__ ld(tmpL, Address(strL));
__ addi(strL, strL, 8);
void compare_string_8_x_LU(Register tmpL, Register tmpU, Register strL, Register strU, Label& DIFF) {
const Register tmp = x30, tmpLval = x12;
__ ld(tmpLval, Address(strL));
__ addi(strL, strL, wordSize);
__ ld(tmpU, Address(strU));
__ addi(strU, strU, 8);
__ inflate_lo32(tmp, tmpL);
__ mv(t0, tmp);
__ xorr(tmp, curU, t0);
__ bnez(tmp, DIFF2);

__ ld(curU, Address(strU));
__ addi(strU, strU, 8);
__ inflate_hi32(tmp, tmpL);
__ mv(t0, tmp);
__ xorr(tmp, tmpU, t0);
__ bnez(tmp, DIFF1);
__ addi(strU, strU, wordSize);
__ inflate_lo32(tmpL, tmpLval);
__ xorr(tmp, tmpU, tmpL);
__ bnez(tmp, DIFF);

__ ld(tmpU, Address(strU));
__ addi(strU, strU, wordSize);
__ inflate_hi32(tmpL, tmpLval);
__ xorr(tmp, tmpU, tmpL);
__ bnez(tmp, DIFF);
}

// x10 = result
Expand All @@ -2307,89 +2304,91 @@ class StubGenerator: public StubCodeGenerator {
__ align(CodeEntryAlignment);
StubCodeMark mark(this, "StubRoutines", isLU ? "compare_long_string_different_encoding LU" : "compare_long_string_different_encoding UL");
address entry = __ pc();
Label SMALL_LOOP, TAIL, TAIL_LOAD_16, LOAD_LAST, DIFF1, DIFF2,
DONE, CALCULATE_DIFFERENCE;
const Register result = x10, str1 = x11, cnt1 = x12, str2 = x13, cnt2 = x14,
tmp1 = x28, tmp2 = x29, tmp3 = x30, tmp4 = x7, tmp5 = x31;
RegSet spilled_regs = RegSet::of(tmp4, tmp5);
Label SMALL_LOOP, TAIL, LOAD_LAST, DONE, CALCULATE_DIFFERENCE;
const Register result = x10, str1 = x11, str2 = x13, cnt2 = x14,
tmp1 = x28, tmp2 = x29, tmp3 = x30, tmp4 = x12;

// cnt2 == amount of characters left to compare
// Check already loaded first 4 symbols
__ inflate_lo32(tmp3, isLU ? tmp1 : tmp2);
__ mv(isLU ? tmp1 : tmp2, tmp3);
__ addi(str1, str1, isLU ? wordSize / 2 : wordSize);
__ addi(str2, str2, isLU ? wordSize : wordSize / 2);
__ sub(cnt2, cnt2, 8); // Already loaded 4 symbols. Last 4 is special case.
__ push_reg(spilled_regs, sp);
__ sub(cnt2, cnt2, wordSize / 2); // Already loaded 4 symbols

if (isLU) {
__ add(str1, str1, cnt2);
__ shadd(str2, cnt2, str2, t0, 1);
} else {
__ shadd(str1, cnt2, str1, t0, 1);
__ add(str2, str2, cnt2);
}
__ xorr(tmp3, tmp1, tmp2);
__ mv(tmp5, tmp2);
__ bnez(tmp3, CALCULATE_DIFFERENCE);

Register strU = isLU ? str2 : str1,
strL = isLU ? str1 : str2,
tmpU = isLU ? tmp5 : tmp1, // where to keep U for comparison
tmpL = isLU ? tmp1 : tmp5; // where to keep L for comparison
tmpU = isLU ? tmp2 : tmp1, // where to keep U for comparison
tmpL = isLU ? tmp1 : tmp2; // where to keep L for comparison

__ sub(tmp2, strL, cnt2); // strL pointer to load from
__ slli(t0, cnt2, 1);
__ sub(cnt1, strU, t0); // strU pointer to load from
// To make sure the main loop is 8-byte aligned, first load another 4 bytes from strL.
// cnt2 is >= 68 here, so there is no need to check that it is >= 0.
__ lwu(tmpL, Address(strL));
__ addi(strL, strL, wordSize / 2);
__ ld(tmpU, Address(strU));
__ addi(strU, strU, wordSize);
__ inflate_lo32(tmp3, tmpL);
__ mv(tmpL, tmp3);
__ xorr(tmp3, tmpU, tmpL);
__ bnez(tmp3, CALCULATE_DIFFERENCE);
__ addi(cnt2, cnt2, -wordSize / 2);

__ ld(tmp4, Address(cnt1));
__ addi(cnt1, cnt1, 8);
__ beqz(cnt2, LOAD_LAST); // no characters left except last load
__ sub(cnt2, cnt2, 16);
// we are now 8-bytes aligned on strL
__ sub(cnt2, cnt2, wordSize * 2);
__ bltz(cnt2, TAIL);
__ bind(SMALL_LOOP); // smaller loop
__ sub(cnt2, cnt2, 16);
compare_string_8_x_LU(tmpL, tmpU, DIFF1, DIFF2);
compare_string_8_x_LU(tmpL, tmpU, DIFF1, DIFF2);
__ sub(cnt2, cnt2, wordSize * 2);
compare_string_8_x_LU(tmpL, tmpU, strL, strU, CALCULATE_DIFFERENCE);
compare_string_8_x_LU(tmpL, tmpU, strL, strU, CALCULATE_DIFFERENCE);
__ bgez(cnt2, SMALL_LOOP);
__ addi(t0, cnt2, 16);
__ beqz(t0, LOAD_LAST);
__ bind(TAIL); // 1..15 characters left until last load (last 4 characters)
// Address of 8 bytes before last 4 characters in UTF-16 string
__ shadd(cnt1, cnt2, cnt1, t0, 1);
// Address of 16 bytes before last 4 characters in Latin1 string
__ add(tmp2, tmp2, cnt2);
__ ld(tmp4, Address(cnt1, -8));
// last 16 characters before last load
compare_string_8_x_LU(tmpL, tmpU, DIFF1, DIFF2);
compare_string_8_x_LU(tmpL, tmpU, DIFF1, DIFF2);
__ j(LOAD_LAST);
__ bind(DIFF2);
__ mv(tmpU, tmp4);
__ bind(DIFF1);
__ mv(tmpL, t0);
__ j(CALCULATE_DIFFERENCE);
__ bind(LOAD_LAST);
// Last 4 UTF-16 characters are already pre-loaded into tmp4 by compare_string_8_x_LU.
// No need to load it again
__ mv(tmpU, tmp4);
__ ld(tmpL, Address(strL));
__ addi(t0, cnt2, wordSize * 2);
__ beqz(t0, DONE);
__ bind(TAIL); // 1..15 characters left
// Aligned access. Load bytes in portions - 4, 2, 1.

__ addi(t0, cnt2, wordSize);
__ addi(cnt2, cnt2, wordSize * 2); // amount of characters left to process
__ bltz(t0, LOAD_LAST);
// 8 or more characters remain, so we can do one compare_string_8_x_LU
compare_string_8_x_LU(tmpL, tmpU, strL, strU, CALCULATE_DIFFERENCE);
__ addi(cnt2, cnt2, -wordSize);
__ beqz(cnt2, DONE); // no character left
__ bind(LOAD_LAST); // cnt2 = 1..7 characters left

__ addi(cnt2, cnt2, -wordSize); // cnt2 is now an offset in strL which points to last 8 bytes
__ slli(t0, cnt2, 1); // t0 is now an offset in strU which points to last 16 bytes
__ add(strL, strL, cnt2); // Address of last 8 bytes in Latin1 string
__ add(strU, strU, t0); // Address of last 16 bytes in UTF-16 string
__ load_int_misaligned(tmpL, Address(strL), t0, false);
__ load_long_misaligned(tmpU, Address(strU), t0, 2);
__ inflate_lo32(tmp3, tmpL);
__ mv(tmpL, tmp3);
__ xorr(tmp3, tmpU, tmpL);
__ beqz(tmp3, DONE);
__ bnez(tmp3, CALCULATE_DIFFERENCE);

__ addi(strL, strL, wordSize / 2); // Address of last 4 bytes in Latin1 string
__ addi(strU, strU, wordSize); // Address of last 8 bytes in UTF-16 string
__ load_int_misaligned(tmpL, Address(strL), t0, false);
__ load_long_misaligned(tmpU, Address(strU), t0, 2);
__ inflate_lo32(tmp3, tmpL);
__ mv(tmpL, tmp3);
__ xorr(tmp3, tmpU, tmpL);
__ bnez(tmp3, CALCULATE_DIFFERENCE);
__ j(DONE); // no character left

// Find the first different characters in the longwords and
// compute their difference.
__ bind(CALCULATE_DIFFERENCE);
__ ctzc_bit(tmp4, tmp3);
__ srl(tmp1, tmp1, tmp4);
__ srl(tmp5, tmp5, tmp4);
__ srl(tmp2, tmp2, tmp4);
__ andi(tmp1, tmp1, 0xFFFF);
__ andi(tmp5, tmp5, 0xFFFF);
__ sub(result, tmp1, tmp5);
__ andi(tmp2, tmp2, 0xFFFF);
__ sub(result, tmp1, tmp2);
__ bind(DONE);
__ pop_reg(spilled_regs, sp);
__ ret();
return entry;
}
Expand Down Expand Up @@ -2502,9 +2501,9 @@ class StubGenerator: public StubCodeGenerator {
__ xorr(tmp4, tmp1, tmp2);
__ bnez(tmp4, DIFF);
__ add(str1, str1, cnt2);
__ ld(tmp5, Address(str1));
__ load_long_misaligned(tmp5, Address(str1), tmp3, isLL ? 1 : 2);
__ add(str2, str2, cnt2);
__ ld(cnt1, Address(str2));
__ load_long_misaligned(cnt1, Address(str2), tmp3, isLL ? 1 : 2);
__ xorr(tmp4, tmp5, cnt1);
__ beqz(tmp4, LENGTH_DIFF);
// Find the first different characters in the longwords and
Expand Down
Expand Up @@ -24,20 +24,20 @@

/*
* @test
* @requires os.arch=="aarch64"
* @requires os.arch=="aarch64" | os.arch=="riscv64"
* @summary String::compareTo implementation uses different algorithms for
* different string lengths. This test creates a string with the
* specified size and a longer string that is the same at the beginning.
* Expecting length delta to be returned. Test class takes 2
* parameters: <string length>, <maximum string length delta>
* Input parameters for this test are set according to Aarch64
* Input parameters for this test are set according to Aarch64/RISC-V
* String::compareTo intrinsic implementation specifics. Aarch64
* implementation has 1, 4, 8 -bytes loops for length < 72 and
* 16, 32, 64 -characters loops for length >= 72. Code is also affected
* 16, 32, 64 -characters loops for length >= 72. Aarch64 Code is also affected
* by SoftwarePrefetchHintDistance vm flag value.
* @run main/othervm -XX:SoftwarePrefetchHintDistance=192 compiler.intrinsics.string.TestStringCompareToDifferentLength 4 2 5 10 13 17 20 23 24 25 71 72 73 88 90 192 193 208 209
* @run main/othervm -XX:SoftwarePrefetchHintDistance=16 compiler.intrinsics.string.TestStringCompareToDifferentLength 4 2 5 10 13 17 20 23 24 25 71 72 73 88 90
* @run main/othervm -XX:SoftwarePrefetchHintDistance=-1 compiler.intrinsics.string.TestStringCompareToDifferentLength 4 2 5 10 13 17 20 23 24 25 71 72 73 88 90
* @run main/othervm -XX:+IgnoreUnrecognizedVMOptions -XX:SoftwarePrefetchHintDistance=192 compiler.intrinsics.string.TestStringCompareToDifferentLength 4 2 5 10 13 17 20 23 24 25 71 72 73 88 90 192 193 208 209
* @run main/othervm -XX:+IgnoreUnrecognizedVMOptions -XX:SoftwarePrefetchHintDistance=16 compiler.intrinsics.string.TestStringCompareToDifferentLength 4 2 5 10 13 17 20 23 24 25 71 72 73 88 90
* @run main/othervm -XX:+IgnoreUnrecognizedVMOptions -XX:SoftwarePrefetchHintDistance=-1 compiler.intrinsics.string.TestStringCompareToDifferentLength 4 2 5 10 13 17 20 23 24 25 71 72 73 88 90
*/

package compiler.intrinsics.string;
Expand Down

3 comments on commit d6245b6

@openjdk-notifier
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@VladimirKempik
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

/backport jdk21u

@openjdk
Copy link

@openjdk openjdk bot commented on d6245b6 Jul 31, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@VladimirKempik the backport was successfully created on the branch VladimirKempik-backport-d6245b68 in my personal fork of openjdk/jdk21u. To create a pull request with this backport targeting openjdk/jdk21u:master, just click the following link:

➡️ Create pull request

The title of the pull request is automatically filled in correctly and below you find a suggestion for the pull request body:

Hi all,

This pull request contains a backport of commit d6245b68 from the openjdk/jdk repository.

The commit being backported was authored by Vladimir Kempik on 28 Jul 2023 and was reviewed by Fei Yang.

Thanks!

If you need to update the source branch of the pull then run the following commands in a local clone of your personal fork of openjdk/jdk21u:

$ git fetch https://github.com/openjdk-bots/jdk21u.git VladimirKempik-backport-d6245b68:VladimirKempik-backport-d6245b68
$ git checkout VladimirKempik-backport-d6245b68
# make changes
$ git add paths/to/changed/files
$ git commit --message 'Describe additional changes made'
$ git push https://github.com/openjdk-bots/jdk21u.git VladimirKempik-backport-d6245b68

Please sign in to comment.