Skip to content

Commit

Permalink
8278897: Alignment of heap segments is not enforced correctly
Browse files Browse the repository at this point in the history
Reviewed-by: jvernee
  • Loading branch information
mcimadamore committed Jan 5, 2022
1 parent 0f4807e commit 9d43d25
Show file tree
Hide file tree
Showing 20 changed files with 600 additions and 34 deletions.
Expand Up @@ -50,7 +50,7 @@ abstract class MemoryAccessVarHandleBase extends VarHandle {
this.alignmentMask = alignmentMask;
}

static IllegalStateException newIllegalStateExceptionForMisalignedAccess(long address) {
return new IllegalStateException("Misaligned access at address: " + address);
static IllegalArgumentException newIllegalArgumentExceptionForMisalignedAccess(long address) {
return new IllegalArgumentException("Misaligned access at address: " + address);
}
}
Expand Up @@ -106,7 +106,7 @@ final class MemoryAccessVarHandle$Type$Helper extends MemoryAccessVarHandleBase
static long offset(boolean skipAlignmentMaskCheck, MemorySegmentProxy bb, long offset, long alignmentMask) {
long address = offsetNoVMAlignCheck(skipAlignmentMaskCheck, bb, offset, alignmentMask);
if ((address & VM_ALIGN) != 0) {
throw MemoryAccessVarHandleBase.newIllegalStateExceptionForMisalignedAccess(address);
throw MemoryAccessVarHandleBase.newIllegalArgumentExceptionForMisalignedAccess(address);
}
return address;
}
Expand All @@ -115,14 +115,15 @@ final class MemoryAccessVarHandle$Type$Helper extends MemoryAccessVarHandleBase
static long offsetNoVMAlignCheck(boolean skipAlignmentMaskCheck, MemorySegmentProxy bb, long offset, long alignmentMask) {
long base = bb.unsafeGetOffset();
long address = base + offset;
long maxAlignMask = bb.maxAlignMask();
if (skipAlignmentMaskCheck) {
//note: the offset portion has already been aligned-checked, by construction
if ((base & alignmentMask) != 0) {
throw MemoryAccessVarHandleBase.newIllegalStateExceptionForMisalignedAccess(address);
if (((base | maxAlignMask) & alignmentMask) != 0) {
throw MemoryAccessVarHandleBase.newIllegalArgumentExceptionForMisalignedAccess(address);
}
} else {
if ((address & alignmentMask) != 0) {
throw MemoryAccessVarHandleBase.newIllegalStateExceptionForMisalignedAccess(address);
if (((address | maxAlignMask) & alignmentMask) != 0) {
throw MemoryAccessVarHandleBase.newIllegalArgumentExceptionForMisalignedAccess(address);
}
}
return address;
Expand Down
Expand Up @@ -44,6 +44,7 @@ public abstract class MemorySegmentProxy {
public abstract Object unsafeGetBase();
public abstract boolean isSmall();
public abstract ScopedMemoryAccess.Scope scope();
public abstract long maxAlignMask();

/* Helper functions for offset computations. These are required so that we can avoid issuing long opcodes
* (e.g. LMUL, LADD) when we're operating on 'small' segments (segments whose length can be expressed with an int).
Expand Down

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Expand Up @@ -124,8 +124,12 @@ public Spliterator<MemorySegment> spliterator(MemoryLayout elementLayout) {
if (elementLayout.byteSize() == 0) {
throw new IllegalArgumentException("Element layout size cannot be zero");
}
if (byteSize() % elementLayout.byteSize() != 0) {
throw new IllegalArgumentException("Segment size is no a multiple of layout size");
Utils.checkElementAlignment(elementLayout, "Element layout alignment greater than its size");
if (!isAlignedForElement(0, elementLayout)) {
throw new IllegalArgumentException("Incompatible alignment constraints");
}
if (!Utils.isAligned(byteSize(), elementLayout.byteSize())) {
throw new IllegalArgumentException("Segment size is not a multiple of layout size");
}
return new SegmentSplitter(elementLayout.byteSize(), byteSize() / elementLayout.byteSize(),
this);
Expand Down Expand Up @@ -383,8 +387,13 @@ private boolean isSet(int mask) {
return (this.mask & mask) != 0;
}

@ForceInline
public final boolean isAlignedForElement(long offset, MemoryLayout layout) {
return (((unsafeGetOffset() + offset) | maxAlignMask()) & (layout.byteAlignment() - 1)) == 0;
}

private int checkArraySize(String typeName, int elemSize) {
if (length % elemSize != 0) {
if (!Utils.isAligned(length, elemSize)) {
throw new IllegalStateException(String.format("Segment size is not a multiple of %d. Size: %d", elemSize, length));
}
long arraySize = length / elemSize;
Expand Down
Expand Up @@ -47,6 +47,11 @@ public abstract class HeapMemorySegmentImpl<H> extends AbstractMemorySegmentImpl
private static final Unsafe UNSAFE = Unsafe.getUnsafe();
private static final int BYTE_ARR_BASE = UNSAFE.arrayBaseOffset(byte[].class);

private static final long MAX_ALIGN_1 = 1;
private static final long MAX_ALIGN_2 = 2;
private static final long MAX_ALIGN_4 = 4;
private static final long MAX_ALIGN_8 = 8;

final long offset;
final H base;

Expand Down Expand Up @@ -100,6 +105,11 @@ public static MemorySegment fromArray(byte[] arr) {
long byteSize = (long)arr.length * Unsafe.ARRAY_BYTE_INDEX_SCALE;
return new OfByte(Unsafe.ARRAY_BYTE_BASE_OFFSET, arr, byteSize, defaultAccessModes(byteSize));
}

@Override
public long maxAlignMask() {
return MAX_ALIGN_1;
}
}

public static class OfChar extends HeapMemorySegmentImpl<char[]> {
Expand All @@ -123,6 +133,11 @@ public static MemorySegment fromArray(char[] arr) {
long byteSize = (long)arr.length * Unsafe.ARRAY_CHAR_INDEX_SCALE;
return new OfChar(Unsafe.ARRAY_CHAR_BASE_OFFSET, arr, byteSize, defaultAccessModes(byteSize));
}

@Override
public long maxAlignMask() {
return MAX_ALIGN_2;
}
}

public static class OfShort extends HeapMemorySegmentImpl<short[]> {
Expand All @@ -146,6 +161,11 @@ public static MemorySegment fromArray(short[] arr) {
long byteSize = (long)arr.length * Unsafe.ARRAY_SHORT_INDEX_SCALE;
return new OfShort(Unsafe.ARRAY_SHORT_BASE_OFFSET, arr, byteSize, defaultAccessModes(byteSize));
}

@Override
public long maxAlignMask() {
return MAX_ALIGN_2;
}
}

public static class OfInt extends HeapMemorySegmentImpl<int[]> {
Expand All @@ -169,6 +189,11 @@ public static MemorySegment fromArray(int[] arr) {
long byteSize = (long)arr.length * Unsafe.ARRAY_INT_INDEX_SCALE;
return new OfInt(Unsafe.ARRAY_INT_BASE_OFFSET, arr, byteSize, defaultAccessModes(byteSize));
}

@Override
public long maxAlignMask() {
return MAX_ALIGN_4;
}
}

public static class OfLong extends HeapMemorySegmentImpl<long[]> {
Expand All @@ -192,6 +217,11 @@ public static MemorySegment fromArray(long[] arr) {
long byteSize = (long)arr.length * Unsafe.ARRAY_LONG_INDEX_SCALE;
return new OfLong(Unsafe.ARRAY_LONG_BASE_OFFSET, arr, byteSize, defaultAccessModes(byteSize));
}

@Override
public long maxAlignMask() {
return MAX_ALIGN_8;
}
}

public static class OfFloat extends HeapMemorySegmentImpl<float[]> {
Expand All @@ -215,6 +245,11 @@ public static MemorySegment fromArray(float[] arr) {
long byteSize = (long)arr.length * Unsafe.ARRAY_FLOAT_INDEX_SCALE;
return new OfFloat(Unsafe.ARRAY_FLOAT_BASE_OFFSET, arr, byteSize, defaultAccessModes(byteSize));
}

@Override
public long maxAlignMask() {
return MAX_ALIGN_4;
}
}

public static class OfDouble extends HeapMemorySegmentImpl<double[]> {
Expand All @@ -238,6 +273,11 @@ public static MemorySegment fromArray(double[] arr) {
long byteSize = (long)arr.length * Unsafe.ARRAY_DOUBLE_INDEX_SCALE;
return new OfDouble(Unsafe.ARRAY_DOUBLE_BASE_OFFSET, arr, byteSize, defaultAccessModes(byteSize));
}

@Override
public long maxAlignMask() {
return MAX_ALIGN_8;
}
}

}
Expand Up @@ -282,11 +282,11 @@ private static IllegalArgumentException badLayoutPath(String cause) {
private static void checkAlignment(LayoutPath path) {
MemoryLayout layout = path.layout;
long alignment = layout.bitAlignment();
if (path.offset % alignment != 0) {
if (!Utils.isAligned(path.offset, alignment)) {
throw new UnsupportedOperationException("Invalid alignment requirements for layout " + layout);
}
for (long stride : path.strides) {
if (stride % alignment != 0) {
if (!Utils.isAligned(stride, alignment)) {
throw new UnsupportedOperationException("Alignment requirements for layout " + layout + " do not match stride " + stride);
}
}
Expand Down
Expand Up @@ -267,6 +267,7 @@ public void set(ValueLayout.OfAddress layout, long offset, Addressable value) {
@CallerSensitive
public char getAtIndex(ValueLayout.OfChar layout, long index) {
Reflection.ensureNativeAccess(Reflection.getCallerClass());
Utils.checkElementAlignment(layout, "Layout alignment greater than its size");
return NativeMemorySegmentImpl.EVERYTHING.get(layout, toRawLongValue() + (index * layout.byteSize()));
}

Expand All @@ -275,6 +276,7 @@ public char getAtIndex(ValueLayout.OfChar layout, long index) {
@CallerSensitive
public void setAtIndex(ValueLayout.OfChar layout, long index, char value) {
Reflection.ensureNativeAccess(Reflection.getCallerClass());
Utils.checkElementAlignment(layout, "Layout alignment greater than its size");
NativeMemorySegmentImpl.EVERYTHING.set(layout, toRawLongValue() + (index * layout.byteSize()), value);
}

Expand All @@ -283,6 +285,7 @@ public void setAtIndex(ValueLayout.OfChar layout, long index, char value) {
@CallerSensitive
public short getAtIndex(ValueLayout.OfShort layout, long index) {
Reflection.ensureNativeAccess(Reflection.getCallerClass());
Utils.checkElementAlignment(layout, "Layout alignment greater than its size");
return NativeMemorySegmentImpl.EVERYTHING.get(layout, toRawLongValue() + (index * layout.byteSize()));
}

Expand All @@ -291,6 +294,7 @@ public short getAtIndex(ValueLayout.OfShort layout, long index) {
@CallerSensitive
public void setAtIndex(ValueLayout.OfShort layout, long index, short value) {
Reflection.ensureNativeAccess(Reflection.getCallerClass());
Utils.checkElementAlignment(layout, "Layout alignment greater than its size");
NativeMemorySegmentImpl.EVERYTHING.set(layout, toRawLongValue() + (index * layout.byteSize()), value);
}

Expand All @@ -299,6 +303,7 @@ public void setAtIndex(ValueLayout.OfShort layout, long index, short value) {
@CallerSensitive
public int getAtIndex(ValueLayout.OfInt layout, long index) {
Reflection.ensureNativeAccess(Reflection.getCallerClass());
Utils.checkElementAlignment(layout, "Layout alignment greater than its size");
return NativeMemorySegmentImpl.EVERYTHING.get(layout, toRawLongValue() + (index * layout.byteSize()));
}

Expand All @@ -307,6 +312,7 @@ public int getAtIndex(ValueLayout.OfInt layout, long index) {
@CallerSensitive
public void setAtIndex(ValueLayout.OfInt layout, long index, int value) {
Reflection.ensureNativeAccess(Reflection.getCallerClass());
Utils.checkElementAlignment(layout, "Layout alignment greater than its size");
NativeMemorySegmentImpl.EVERYTHING.set(layout, toRawLongValue() + (index * layout.byteSize()), value);
}

Expand All @@ -315,6 +321,7 @@ public void setAtIndex(ValueLayout.OfInt layout, long index, int value) {
@CallerSensitive
public float getAtIndex(ValueLayout.OfFloat layout, long index) {
Reflection.ensureNativeAccess(Reflection.getCallerClass());
Utils.checkElementAlignment(layout, "Layout alignment greater than its size");
return NativeMemorySegmentImpl.EVERYTHING.get(layout, toRawLongValue() + (index * layout.byteSize()));
}

Expand All @@ -323,6 +330,7 @@ public float getAtIndex(ValueLayout.OfFloat layout, long index) {
@CallerSensitive
public void setAtIndex(ValueLayout.OfFloat layout, long index, float value) {
Reflection.ensureNativeAccess(Reflection.getCallerClass());
Utils.checkElementAlignment(layout, "Layout alignment greater than its size");
NativeMemorySegmentImpl.EVERYTHING.set(layout, toRawLongValue() + (index * layout.byteSize()), value);
}

Expand All @@ -331,6 +339,7 @@ public void setAtIndex(ValueLayout.OfFloat layout, long index, float value) {
@CallerSensitive
public long getAtIndex(ValueLayout.OfLong layout, long index) {
Reflection.ensureNativeAccess(Reflection.getCallerClass());
Utils.checkElementAlignment(layout, "Layout alignment greater than its size");
return NativeMemorySegmentImpl.EVERYTHING.get(layout, toRawLongValue() + (index * layout.byteSize()));
}

Expand All @@ -339,6 +348,7 @@ public long getAtIndex(ValueLayout.OfLong layout, long index) {
@CallerSensitive
public void setAtIndex(ValueLayout.OfLong layout, long index, long value) {
Reflection.ensureNativeAccess(Reflection.getCallerClass());
Utils.checkElementAlignment(layout, "Layout alignment greater than its size");
NativeMemorySegmentImpl.EVERYTHING.set(layout, toRawLongValue() + (index * layout.byteSize()), value);
}

Expand All @@ -347,6 +357,7 @@ public void setAtIndex(ValueLayout.OfLong layout, long index, long value) {
@CallerSensitive
public double getAtIndex(ValueLayout.OfDouble layout, long index) {
Reflection.ensureNativeAccess(Reflection.getCallerClass());
Utils.checkElementAlignment(layout, "Layout alignment greater than its size");
return NativeMemorySegmentImpl.EVERYTHING.get(layout, toRawLongValue() + (index * layout.byteSize()));
}

Expand All @@ -355,6 +366,7 @@ public double getAtIndex(ValueLayout.OfDouble layout, long index) {
@CallerSensitive
public void setAtIndex(ValueLayout.OfDouble layout, long index, double value) {
Reflection.ensureNativeAccess(Reflection.getCallerClass());
Utils.checkElementAlignment(layout, "Layout alignment greater than its size");
NativeMemorySegmentImpl.EVERYTHING.set(layout, toRawLongValue() + (index * layout.byteSize()), value);
}

Expand All @@ -363,6 +375,7 @@ public void setAtIndex(ValueLayout.OfDouble layout, long index, double value) {
@CallerSensitive
public MemoryAddress getAtIndex(ValueLayout.OfAddress layout, long index) {
Reflection.ensureNativeAccess(Reflection.getCallerClass());
Utils.checkElementAlignment(layout, "Layout alignment greater than its size");
return NativeMemorySegmentImpl.EVERYTHING.get(layout, toRawLongValue() + (index * layout.byteSize()));
}

Expand All @@ -371,6 +384,7 @@ public MemoryAddress getAtIndex(ValueLayout.OfAddress layout, long index) {
@CallerSensitive
public void setAtIndex(ValueLayout.OfAddress layout, long index, Addressable value) {
Reflection.ensureNativeAccess(Reflection.getCallerClass());
Utils.checkElementAlignment(layout, "Layout alignment greater than its size");
NativeMemorySegmentImpl.EVERYTHING.set(layout, toRawLongValue() + (index * layout.byteSize()), value.address());
}
}
Expand Up @@ -92,6 +92,11 @@ Object base() {
return null;
}

@Override
public long maxAlignMask() {
return 0;
}

// factories

public static MemorySegment makeNativeSegment(long bytesSize, long alignmentBytes, ResourceScopeImpl scope) {
Expand Down
Expand Up @@ -98,7 +98,7 @@ public static MemorySegment alignUp(MemorySegment ms, long alignment) {
}

public static long bitsToBytesOrThrow(long bits, Supplier<RuntimeException> exFactory) {
if (bits % 8 == 0) {
if (Utils.isAligned(bits, 8)) {
return bits / 8;
} else {
throw exFactory.get();
Expand Down Expand Up @@ -173,4 +173,16 @@ public static long scaleOffset(MemorySegment segment, long index, long size) {
// note: we know size is a small value (as it comes from ValueLayout::byteSize())
return MemorySegmentProxy.multiplyOffsets(index, (int)size, (AbstractMemorySegmentImpl)segment);
}

@ForceInline
public static boolean isAligned(long offset, long align) {
return (offset & (align - 1)) == 0;
}

@ForceInline
public static void checkElementAlignment(MemoryLayout layout, String msg) {
if (layout.byteAlignment() > layout.byteSize()) {
throw new IllegalArgumentException(msg);
}
}
}
12 changes: 12 additions & 0 deletions test/jdk/java/foreign/TestArrayCopy.java
Expand Up @@ -253,6 +253,18 @@ public void testCarrierMismatchDst() {
MemorySegment.copy(new byte[] { 1, 2, 3, 4 }, 0, segment, JAVA_INT, 0, 4);
}

@Test(expectedExceptions = IllegalArgumentException.class)
public void testHyperAlignedSrc() {
MemorySegment segment = MemorySegment.ofArray(new byte[] {1, 2, 3, 4});
MemorySegment.copy(new byte[] { 1, 2, 3, 4 }, 0, segment, JAVA_BYTE.withBitAlignment(16), 0, 4);
}

@Test(expectedExceptions = IllegalArgumentException.class)
public void testHyperAlignedDst() {
MemorySegment segment = MemorySegment.ofArray(new byte[] {1, 2, 3, 4});
MemorySegment.copy(segment, JAVA_BYTE.withBitAlignment(16), 0, new byte[] { 1, 2, 3, 4 }, 0, 4);
}

/***** Utilities *****/

public static MemorySegment srcSegment(int bytesLength) {
Expand Down

0 comments on commit 9d43d25

Please sign in to comment.