Skip to content

Commit 4fd24a6

Browse files
committed
[GR-66145] Improve the intrinsification of Vector::slice
PullRequest: graal/21422
2 parents 4d8b719 + a4b799d commit 4fd24a6

File tree

7 files changed

+349
-2
lines changed

7 files changed

+349
-2
lines changed

compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/lir/amd64/vector/AMD64VectorShuffle.java

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@
9797
import jdk.vm.ci.code.Register;
9898
import jdk.vm.ci.meta.AllocatableValue;
9999
import jdk.vm.ci.meta.JavaConstant;
100+
import jdk.vm.ci.meta.PlatformKind;
100101
import jdk.vm.ci.meta.Value;
101102

102103
public class AMD64VectorShuffle {
@@ -360,6 +361,132 @@ private void emitBytePermute(CompilationResultBuilder crb, AMD64MacroAssembler m
360361
}
361362
}
362363

364+
/**
365+
* A slice operation, see {@link jdk.graal.compiler.vector.nodes.amd64.AMD64SimdSliceNode}.
366+
*/
367+
public static final class SliceOp extends AMD64LIRInstruction {
368+
public static final LIRInstructionClass<SliceOp> TYPE = LIRInstructionClass.create(SliceOp.class);
369+
370+
@Def({OperandFlag.REG}) protected AllocatableValue result;
371+
@Alive({OperandFlag.REG}) protected AllocatableValue src1;
372+
@Alive({OperandFlag.REG}) protected AllocatableValue src2;
373+
@Temp({OperandFlag.REG, OperandFlag.ILLEGAL}) protected AllocatableValue tmp1;
374+
@Temp({OperandFlag.REG, OperandFlag.ILLEGAL}) protected AllocatableValue tmp2;
375+
private final int originInBytes;
376+
private final AMD64SIMDInstructionEncoding encoding;
377+
378+
public SliceOp(AMD64LIRGenerator gen, AllocatableValue result, AllocatableValue src1, AllocatableValue src2, int origin, AMD64SIMDInstructionEncoding encoding) {
379+
super(TYPE);
380+
AMD64Kind eKind = ((AMD64Kind) result.getPlatformKind()).getScalar();
381+
this.result = result;
382+
this.src1 = src1;
383+
this.src2 = src2;
384+
this.originInBytes = origin * eKind.getSizeInBytes();
385+
this.encoding = encoding;
386+
allocateTempIfNecessary(gen);
387+
}
388+
389+
@Override
390+
public void emitCode(CompilationResultBuilder crb, AMD64MacroAssembler masm) {
391+
int resultSize = result.getPlatformKind().getSizeInBytes();
392+
switch (resultSize) {
393+
case 4 -> {
394+
if (src1.equals(src2) && originInBytes == 2) {
395+
VexRMIOp.VPSHUFLW.encoding(encoding).emit(masm, XMM, asRegister(result), asRegister(src1), 0x1);
396+
} else {
397+
VexRMIOp.VPSHUFD.encoding(encoding).emit(masm, XMM, asRegister(tmp1), asRegister(src1), 0);
398+
VexRVMIOp.VPALIGNR.encoding(encoding).emit(masm, XMM, asRegister(result), asRegister(src2), asRegister(tmp1), originInBytes + 12);
399+
}
400+
}
401+
case 8 -> {
402+
if (src1.equals(src2) && originInBytes % 2 == 0) {
403+
int imm;
404+
if (originInBytes == 2) {
405+
imm = 0b00111001;
406+
} else if (originInBytes == 4) {
407+
imm = 0b01001110;
408+
} else {
409+
GraalError.guarantee(originInBytes == 6, "unexpected originInBytes %d", originInBytes);
410+
imm = 0b10010011;
411+
}
412+
VexRMIOp.VPSHUFLW.encoding(encoding).emit(masm, XMM, asRegister(result), asRegister(src1), imm);
413+
} else {
414+
VexRMIOp.VPSHUFD.encoding(encoding).emit(masm, XMM, asRegister(tmp1), asRegister(src1), 0x40);
415+
VexRVMIOp.VPALIGNR.encoding(encoding).emit(masm, XMM, asRegister(result), asRegister(src2), asRegister(tmp1), originInBytes + 8);
416+
}
417+
}
418+
case 16 -> VexRVMIOp.VPALIGNR.encoding(encoding).emit(masm, XMM, asRegister(result), asRegister(src2), asRegister(src1), originInBytes);
419+
case 32 -> {
420+
if (encoding == AMD64SIMDInstructionEncoding.VEX || originInBytes % Integer.BYTES != 0) {
421+
Register tmp = originInBytes == 16 ? asRegister(result) : asRegister(tmp1);
422+
if (encoding == AMD64SIMDInstructionEncoding.VEX) {
423+
VexRVMIOp.VPERM2I128.emit(masm, YMM, tmp, asRegister(src1), asRegister(src2), 0x21);
424+
} else {
425+
VexRVMIOp.EVALIGND.emit(masm, YMM, tmp, asRegister(src2), asRegister(src1), 4);
426+
}
427+
if (originInBytes < 16) {
428+
VexRVMIOp.VPALIGNR.encoding(encoding).emit(masm, YMM, asRegister(result), asRegister(tmp1), asRegister(src1), originInBytes);
429+
} else if (originInBytes > 16) {
430+
VexRVMIOp.VPALIGNR.encoding(encoding).emit(masm, YMM, asRegister(result), asRegister(src2), asRegister(tmp1), originInBytes - 16);
431+
}
432+
} else {
433+
VexRVMIOp.EVALIGND.emit(masm, YMM, asRegister(result), asRegister(src2), asRegister(src1), originInBytes / Integer.BYTES);
434+
}
435+
}
436+
case 64 -> {
437+
GraalError.guarantee(encoding == AMD64SIMDInstructionEncoding.EVEX, "unexpected encoding with 512-bit vector");
438+
if (originInBytes % 4 != 0) {
439+
if (originInBytes < 16) {
440+
VexRVMIOp.EVALIGND.emit(masm, ZMM, asRegister(tmp1), asRegister(src2), asRegister(src1), 4);
441+
VexRVMIOp.EVPALIGNR.emit(masm, ZMM, asRegister(result), asRegister(tmp1), asRegister(src1), originInBytes);
442+
} else if (originInBytes < 32) {
443+
VexRVMIOp.EVALIGND.emit(masm, ZMM, asRegister(tmp1), asRegister(src2), asRegister(src1), 4);
444+
VexRVMIOp.EVALIGND.emit(masm, ZMM, asRegister(tmp2), asRegister(src2), asRegister(src1), 8);
445+
VexRVMIOp.EVPALIGNR.emit(masm, ZMM, asRegister(result), asRegister(tmp2), asRegister(tmp1), originInBytes - 16);
446+
} else if (originInBytes < 48) {
447+
VexRVMIOp.EVALIGND.emit(masm, ZMM, asRegister(tmp1), asRegister(src2), asRegister(src1), 8);
448+
VexRVMIOp.EVALIGND.emit(masm, ZMM, asRegister(tmp2), asRegister(src2), asRegister(src1), 12);
449+
VexRVMIOp.EVPALIGNR.emit(masm, ZMM, asRegister(result), asRegister(tmp2), asRegister(tmp1), originInBytes - 32);
450+
} else {
451+
VexRVMIOp.EVALIGND.emit(masm, ZMM, asRegister(tmp1), asRegister(src2), asRegister(src1), 12);
452+
VexRVMIOp.EVPALIGNR.emit(masm, ZMM, asRegister(result), asRegister(src2), asRegister(tmp1), originInBytes - 48);
453+
}
454+
} else {
455+
VexRVMIOp.EVALIGND.emit(masm, ZMM, asRegister(result), asRegister(src2), asRegister(src1), originInBytes / Integer.BYTES);
456+
}
457+
}
458+
default -> GraalError.shouldNotReachHereUnexpectedValue(resultSize);
459+
}
460+
}
461+
462+
private void allocateTempIfNecessary(AMD64LIRGenerator gen) {
463+
PlatformKind resultKind = result.getPlatformKind();
464+
boolean needsTemp;
465+
if (resultKind.getSizeInBytes() < XMM.getBytes()) {
466+
needsTemp = !src1.equals(src2) || originInBytes % 2 != 0;
467+
} else if (resultKind.getSizeInBytes() == XMM.getBytes()) {
468+
needsTemp = false;
469+
} else if (encoding == AMD64SIMDInstructionEncoding.VEX) {
470+
needsTemp = true;
471+
} else {
472+
needsTemp = (originInBytes % Integer.BYTES != 0);
473+
}
474+
if (needsTemp) {
475+
tmp1 = gen.newVariable(LIRKind.value(resultKind));
476+
} else {
477+
tmp1 = Value.ILLEGAL;
478+
}
479+
480+
if (resultKind.getSizeInBytes() == ZMM.getBytes() && originInBytes % Integer.BYTES != 0 &&
481+
originInBytes > 16 && originInBytes < 48) {
482+
GraalError.guarantee(!tmp1.equals(Value.ILLEGAL), "must have tmp1 with originInBytes = %d", originInBytes);
483+
tmp2 = gen.newVariable(LIRKind.value(resultKind));
484+
} else {
485+
tmp2 = Value.ILLEGAL;
486+
}
487+
}
488+
}
489+
363490
public static final class IntToVectorOp extends AMD64LIRInstruction {
364491
public static final LIRInstructionClass<IntToVectorOp> TYPE = LIRInstructionClass.create(IntToVectorOp.class);
365492

compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/vector/lir/amd64/AMD64VectorArithmeticLIRGenerator.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,4 +510,14 @@ public Value emitVectorPermute(LIRKind resultKind, Value source, Value indices)
510510
getLIRGen().append(AMD64VectorShuffle.PermuteOp.create(getAMD64LIRGen(), result, asAllocatable(source), asAllocatable(indices), getSimdEncoding()));
511511
return result;
512512
}
513+
514+
/**
515+
* Do a slice operation, see
516+
* {@code jdk.incubator.vector.Vector<E>::slice(int, jdk.incubator.vector.Vector<E>)}.
517+
*/
518+
public Value emitVectorSlice(LIRKind resultKind, Value src1, Value src2, int origin) {
519+
Variable result = getLIRGen().newVariable(resultKind);
520+
getLIRGen().append(new AMD64VectorShuffle.SliceOp(getAMD64LIRGen(), result, asAllocatable(src1), asAllocatable(src2), origin, getSimdEncoding()));
521+
return result;
522+
}
513523
}
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
/*
2+
* Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
3+
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4+
*
5+
* This code is free software; you can redistribute it and/or modify it
6+
* under the terms of the GNU General Public License version 2 only, as
7+
* published by the Free Software Foundation. Oracle designates this
8+
* particular file as subject to the "Classpath" exception as provided
9+
* by Oracle in the LICENSE file that accompanied this code.
10+
*
11+
* This code is distributed in the hope that it will be useful, but WITHOUT
12+
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13+
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14+
* version 2 for more details (a copy is included in the LICENSE file that
15+
* accompanied this code).
16+
*
17+
* You should have received a copy of the GNU General Public License version
18+
* 2 along with this work; if not, write to the Free Software Foundation,
19+
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20+
*
21+
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22+
* or visit www.oracle.com if you need additional information or have any
23+
* questions.
24+
*/
25+
package jdk.graal.compiler.vector.nodes.amd64;
26+
27+
import jdk.graal.compiler.core.common.LIRKind;
28+
import jdk.graal.compiler.debug.GraalError;
29+
import jdk.graal.compiler.graph.NodeClass;
30+
import jdk.graal.compiler.nodeinfo.NodeCycles;
31+
import jdk.graal.compiler.nodeinfo.NodeInfo;
32+
import jdk.graal.compiler.nodeinfo.NodeSize;
33+
import jdk.graal.compiler.nodes.NodeView;
34+
import jdk.graal.compiler.nodes.ValueNode;
35+
import jdk.graal.compiler.nodes.calc.FloatingNode;
36+
import jdk.graal.compiler.nodes.spi.NodeLIRBuilderTool;
37+
import jdk.graal.compiler.vector.lir.VectorLIRGeneratorTool;
38+
import jdk.graal.compiler.vector.lir.VectorLIRLowerable;
39+
import jdk.graal.compiler.vector.lir.amd64.AMD64VectorArithmeticLIRGenerator;
40+
import jdk.graal.compiler.vector.nodes.simd.SimdStamp;
41+
42+
/**
43+
* A slice operation concatenates its inputs into a sequence of {@code VLENGTH * 2} elements, then
44+
* {@code VLENGTH} elements are collected starting at index {@link #origin} to form the result. If
45+
* the 2 inputs are the same, then this operation is the same as rotating the input left by
46+
* {@link #origin} elements.
47+
*/
48+
@NodeInfo(cycles = NodeCycles.CYCLES_1, size = NodeSize.SIZE_1)
49+
public class AMD64SimdSliceNode extends FloatingNode implements VectorLIRLowerable {
50+
51+
public static final NodeClass<AMD64SimdSliceNode> TYPE = NodeClass.create(AMD64SimdSliceNode.class);
52+
53+
@Input protected ValueNode src1;
54+
@Input protected ValueNode src2;
55+
private final int origin;
56+
57+
protected AMD64SimdSliceNode(SimdStamp stamp, ValueNode src1, ValueNode src2, int origin) {
58+
super(TYPE, stamp);
59+
this.src1 = src1;
60+
this.src2 = src2;
61+
this.origin = origin;
62+
}
63+
64+
public static AMD64SimdSliceNode create(ValueNode src1, ValueNode src2, int origin) {
65+
GraalError.guarantee(src1.stamp(NodeView.DEFAULT) instanceof SimdStamp, "unexpected input stamp %s", src1);
66+
SimdStamp stamp = (SimdStamp) src1.stamp(NodeView.DEFAULT).unrestricted();
67+
GraalError.guarantee(stamp.isCompatible(src2.stamp(NodeView.DEFAULT)), "unexpected input stamps: %s, %s", src1, src2);
68+
GraalError.guarantee(origin > 0 && origin < stamp.getVectorLength(), "unexpected origin %d of vector input %s", origin, src1);
69+
return new AMD64SimdSliceNode(stamp, src1, src2, origin);
70+
}
71+
72+
public ValueNode getSrc1() {
73+
return src1;
74+
}
75+
76+
public ValueNode getSrc2() {
77+
return src2;
78+
}
79+
80+
public int getOrigin() {
81+
return origin;
82+
}
83+
84+
@Override
85+
public void generate(NodeLIRBuilderTool builder, VectorLIRGeneratorTool gen) {
86+
LIRKind resultKind = builder.getLIRGeneratorTool().getLIRKind(stamp);
87+
builder.setResult(this, ((AMD64VectorArithmeticLIRGenerator) gen).emitVectorSlice(resultKind, builder.operand(src1), builder.operand(src2), origin));
88+
}
89+
}

compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/vector/nodes/simd/SimdBlendWithConstantMaskNode.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,10 @@ public Stamp foldStamp(Stamp falseStamp, Stamp trueStamp) {
8989

9090
@Override
9191
public Node canonical(CanonicalizerTool tool, ValueNode falseValues, ValueNode trueValues) {
92+
if (falseValues == trueValues) {
93+
return falseValues;
94+
}
95+
9296
boolean allSecond = true;
9397
boolean allFirst = true;
9498
for (int i = 0; (allSecond || allFirst) && i < selector.length; ++i) {

compiler/src/jdk.graal.compiler/src/jdk/graal/compiler/vector/nodes/simd/SimdBlendWithLogicMaskNode.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,10 @@ public Stamp foldStamp(Stamp stampX, Stamp stampY, Stamp stampZ) {
8686

8787
@Override
8888
public Node canonical(CanonicalizerTool tool, ValueNode forX, ValueNode forY, ValueNode forZ) {
89+
if (forX == forY) {
90+
return forX;
91+
}
92+
8993
ValueNode mask = forZ;
9094
Stamp toStamp = mask.stamp(NodeView.DEFAULT);
9195
while (mask instanceof ReinterpretNode simdMask && SimdStamp.isOpmask(simdMask.stamp(NodeView.DEFAULT))) {

0 commit comments

Comments
 (0)