Skip to content

Commit

Permalink
8303516: HFAs with nested structs/unions/arrays not handled correctly…
Browse files Browse the repository at this point in the history
… on AArch64

Reviewed-by: mcimadamore
  • Loading branch information
JornVernee committed Mar 3, 2023
1 parent cbdc7a6 commit c6de66c
Show file tree
Hide file tree
Showing 8 changed files with 872 additions and 260 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020, 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2020, 2023, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
Expand Down Expand Up @@ -158,25 +158,29 @@ private static MemorySegment bufferCopy(MemorySegment dest, MemorySegment buffer
}

public static Class<?> primitiveCarrierForSize(long size, boolean useFloat) {
return primitiveLayoutForSize(size, useFloat).carrier();
}

public static ValueLayout primitiveLayoutForSize(long size, boolean useFloat) {
if (useFloat) {
if (size == 4) {
return float.class;
return JAVA_FLOAT;
} else if (size == 8) {
return double.class;
return JAVA_DOUBLE;
}
} else {
if (size == 1) {
return byte.class;
return JAVA_BYTE;
} else if (size == 2) {
return short.class;
return JAVA_SHORT;
} else if (size <= 4) {
return int.class;
return JAVA_INT;
} else if (size <= 8) {
return long.class;
return JAVA_LONG;
}
}

throw new IllegalArgumentException("No type for size: " + size + " isFloat=" + useFloat);
throw new IllegalArgumentException("No layout for size: " + size + " isFloat=" + useFloat);
}

public static Linker getSystemLinker() {
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020, 2021, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2020, 2023, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2020, 2021, Arm Limited. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
Expand Down Expand Up @@ -28,7 +28,10 @@
import java.lang.foreign.GroupLayout;
import java.lang.foreign.MemoryLayout;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.SequenceLayout;
import java.lang.foreign.ValueLayout;
import java.util.List;
import java.util.ArrayList;

public enum TypeClass {
STRUCT_REGISTER,
Expand Down Expand Up @@ -58,15 +61,38 @@ static boolean isRegisterAggregate(MemoryLayout type) {
return type.bitSize() <= MAX_AGGREGATE_REGS_SIZE * 64;
}

static List<MemoryLayout> scalarLayouts(GroupLayout gl) {
List<MemoryLayout> out = new ArrayList<>();
scalarLayoutsInternal(out, gl);
return out;
}

private static void scalarLayoutsInternal(List<MemoryLayout> out, GroupLayout gl) {
for (MemoryLayout member : gl.memberLayouts()) {
if (member instanceof GroupLayout memberGl) {
scalarLayoutsInternal(out, memberGl);
} else if (member instanceof SequenceLayout memberSl) {
for (long i = 0; i < memberSl.elementCount(); i++) {
out.add(memberSl.elementLayout());
}
} else {
// padding or value layouts
out.add(member);
}
}
}

static boolean isHomogeneousFloatAggregate(MemoryLayout type) {
if (!(type instanceof GroupLayout groupLayout))
return false;

final int numElements = groupLayout.memberLayouts().size();
List<MemoryLayout> scalarLayouts = scalarLayouts(groupLayout);

final int numElements = scalarLayouts.size();
if (numElements > 4 || numElements == 0)
return false;

MemoryLayout baseType = groupLayout.memberLayouts().get(0);
MemoryLayout baseType = scalarLayouts.get(0);

if (!(baseType instanceof ValueLayout))
return false;
Expand All @@ -75,7 +101,7 @@ static boolean isHomogeneousFloatAggregate(MemoryLayout type) {
if (baseArgClass != FLOAT)
return false;

for (MemoryLayout elem : groupLayout.memberLayouts()) {
for (MemoryLayout elem : scalarLayouts) {
if (!(elem instanceof ValueLayout))
return false;

Expand Down
154 changes: 153 additions & 1 deletion test/jdk/java/foreign/NativeTestHelper.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020, 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2020, 2023, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
Expand All @@ -22,21 +22,50 @@
*
*/

import java.lang.foreign.Arena;
import java.lang.foreign.FunctionDescriptor;
import java.lang.foreign.GroupLayout;
import java.lang.foreign.Linker;
import java.lang.foreign.MemoryLayout;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.PaddingLayout;
import java.lang.foreign.SegmentAllocator;
import java.lang.foreign.SegmentScope;
import java.lang.foreign.SequenceLayout;
import java.lang.foreign.StructLayout;
import java.lang.foreign.SymbolLookup;
import java.lang.foreign.UnionLayout;
import java.lang.foreign.ValueLayout;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.invoke.VarHandle;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.UnaryOperator;
import java.util.random.RandomGenerator;

import static java.lang.foreign.MemoryLayout.PathElement.groupElement;
import static java.lang.foreign.MemoryLayout.PathElement.sequenceElement;

public class NativeTestHelper {

public static final boolean IS_WINDOWS = System.getProperty("os.name").startsWith("Windows");

private static final MethodHandle MH_SAVER;

static {
try {
MH_SAVER = MethodHandles.lookup().findStatic(NativeTestHelper.class, "saver",
MethodType.methodType(Object.class, Object[].class, List.class, AtomicReference.class, SegmentAllocator.class, int.class));
} catch (ReflectiveOperationException e) {
throw new ExceptionInInitializerError(e);
}
}

public static boolean isIntegral(MemoryLayout layout) {
return layout instanceof ValueLayout valueLayout && isIntegral(valueLayout.carrier());
}
Expand Down Expand Up @@ -126,4 +155,127 @@ public static MemorySegment upcallStub(Class<?> holder, String name, FunctionDes
throw new RuntimeException(e);
}
}

public record TestValue (Object value, Consumer<Object> check) {}

public static TestValue genTestValue(RandomGenerator random, MemoryLayout layout, SegmentAllocator allocator) {
if (layout instanceof StructLayout struct) {
MemorySegment segment = allocator.allocate(struct);
List<Consumer<Object>> fieldChecks = new ArrayList<>();
for (MemoryLayout fieldLayout : struct.memberLayouts()) {
if (fieldLayout instanceof PaddingLayout) continue;
MemoryLayout.PathElement fieldPath = groupElement(fieldLayout.name().orElseThrow());
fieldChecks.add(initField(random, segment, struct, fieldLayout, fieldPath, allocator));
}
return new TestValue(segment, actual -> fieldChecks.forEach(check -> check.accept(actual)));
} else if (layout instanceof UnionLayout union) {
MemorySegment segment = allocator.allocate(union);
List<MemoryLayout> filteredFields = union.memberLayouts().stream()
.filter(l -> !(l instanceof PaddingLayout))
.toList();
int fieldIdx = random.nextInt(filteredFields.size());
MemoryLayout fieldLayout = filteredFields.get(fieldIdx);
MemoryLayout.PathElement fieldPath = groupElement(fieldLayout.name().orElseThrow());
Consumer<Object> check = initField(random, segment, union, fieldLayout, fieldPath, allocator);
return new TestValue(segment, check);
} else if (layout instanceof SequenceLayout array) {
MemorySegment segment = allocator.allocate(array);
List<Consumer<Object>> elementChecks = new ArrayList<>();
for (int i = 0; i < array.elementCount(); i++) {
elementChecks.add(initField(random, segment, array, array.elementLayout(), sequenceElement(i), allocator));
}
return new TestValue(segment, actual -> elementChecks.forEach(check -> check.accept(actual)));
} else if (layout instanceof ValueLayout.OfAddress) {
MemorySegment value = MemorySegment.ofAddress(random.nextLong());
return new TestValue(value, actual -> assertEquals(actual, value));
}else if (layout instanceof ValueLayout.OfByte) {
byte value = (byte) random.nextInt();
return new TestValue(value, actual -> assertEquals(actual, value));
} else if (layout instanceof ValueLayout.OfShort) {
short value = (short) random.nextInt();
return new TestValue(value, actual -> assertEquals(actual, value));
} else if (layout instanceof ValueLayout.OfInt) {
int value = random.nextInt();
return new TestValue(value, actual -> assertEquals(actual, value));
} else if (layout instanceof ValueLayout.OfLong) {
long value = random.nextLong();
return new TestValue(value, actual -> assertEquals(actual, value));
} else if (layout instanceof ValueLayout.OfFloat) {
float value = random.nextFloat();
return new TestValue(value, actual -> assertEquals(actual, value));
} else if (layout instanceof ValueLayout.OfDouble) {
double value = random.nextDouble();
return new TestValue(value, actual -> assertEquals(actual, value));
}

throw new IllegalStateException("Unexpected layout: " + layout);
}

private static Consumer<Object> initField(RandomGenerator random, MemorySegment container, MemoryLayout containerLayout,
MemoryLayout fieldLayout, MemoryLayout.PathElement fieldPath,
SegmentAllocator allocator) {
TestValue fieldValue = genTestValue(random, fieldLayout, allocator);
Consumer<Object> fieldCheck = fieldValue.check();
if (fieldLayout instanceof GroupLayout || fieldLayout instanceof SequenceLayout) {
UnaryOperator<MemorySegment> slicer = slicer(containerLayout, fieldPath);
MemorySegment slice = slicer.apply(container);
slice.copyFrom((MemorySegment) fieldValue.value());
return actual -> fieldCheck.accept(slicer.apply((MemorySegment) actual));
} else {
VarHandle accessor = containerLayout.varHandle(fieldPath);
//set value
accessor.set(container, fieldValue.value());
return actual -> fieldCheck.accept(accessor.get((MemorySegment) actual));
}
}

private static UnaryOperator<MemorySegment> slicer(MemoryLayout containerLayout, MemoryLayout.PathElement fieldPath) {
MethodHandle slicer = containerLayout.sliceHandle(fieldPath);
return container -> {
try {
return (MemorySegment) slicer.invokeExact(container);
} catch (Throwable e) {
throw new IllegalStateException(e);
}
};
}

private static void assertEquals(Object actual, Object expected) {
if (actual.getClass() != expected.getClass()) {
throw new AssertionError("Type mismatch: " + actual.getClass() + " != " + expected.getClass());
}
if (!actual.equals(expected)) {
throw new AssertionError("Not equal: " + actual + " != " + expected);
}
}

/**
* Make an upcall stub that saves its arguments into the given 'ref' array
*
* @param fd function descriptor for the upcall stub
* @param capturedArgs box to save arguments in
* @param arena allocator for making copies of by-value structs
* @param retIdx the index of the argument to return
* @return return the upcall stub
*/
public static MemorySegment makeArgSaverCB(FunctionDescriptor fd, Arena arena,
AtomicReference<Object[]> capturedArgs, int retIdx) {
MethodHandle target = MethodHandles.insertArguments(MH_SAVER, 1, fd.argumentLayouts(), capturedArgs, arena, retIdx);
target = target.asCollector(Object[].class, fd.argumentLayouts().size());
target = target.asType(fd.toMethodType());
return LINKER.upcallStub(target, fd, arena.scope());
}

private static Object saver(Object[] o, List<MemoryLayout> argLayouts, AtomicReference<Object[]> ref, SegmentAllocator allocator, int retArg) {
for (int i = 0; i < o.length; i++) {
if (argLayouts.get(i) instanceof GroupLayout gl) {
MemorySegment ms = (MemorySegment) o[i];
MemorySegment copy = allocator.allocate(gl);
copy.copyFrom(ms);
o[i] = copy;
}
}
ref.set(o);
return retArg != -1 ? o[retArg] : null;
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020, 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2020, 2023, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
Expand Down Expand Up @@ -426,4 +426,54 @@ public void testVarArgsInRegs() {

checkReturnBindings(callingSequence, new Binding[]{});
}

@Test
public void testFloatArrayStruct() {
// should be classified as HFA
StructLayout S10 = MemoryLayout.structLayout(
MemoryLayout.sequenceLayout(4, C_DOUBLE)
);
MethodType mt = MethodType.methodType(MemorySegment.class, MemorySegment.class);
FunctionDescriptor fd = FunctionDescriptor.of(S10, S10);
FunctionDescriptor fdExpected = FunctionDescriptor.of(S10, ADDRESS, ADDRESS, S10); // uses return buffer
CallArranger.Bindings bindings = CallArranger.LINUX.getBindings(mt, fd, false);

assertFalse(bindings.isInMemoryReturn());
CallingSequence callingSequence = bindings.callingSequence();
assertEquals(callingSequence.callerMethodType(), mt.insertParameterTypes(0, MemorySegment.class, MemorySegment.class));
assertEquals(callingSequence.functionDesc(), fdExpected);

// This is identical to the non-variadic calling sequence
checkArgumentBindings(callingSequence, new Binding[][]{
{ unboxAddress(), vmStore(RETURN_BUFFER_STORAGE, long.class) },
{ unboxAddress(), vmStore(TARGET_ADDRESS_STORAGE, long.class) },
{ dup(),
bufferLoad(0, double.class),
vmStore(v0, double.class),
dup(),
bufferLoad(8, double.class),
vmStore(v1, double.class),
dup(),
bufferLoad(16, double.class),
vmStore(v2, double.class),
bufferLoad(24, double.class),
vmStore(v3, double.class) },
});

checkReturnBindings(callingSequence, new Binding[]{
allocate(S10),
dup(),
vmLoad(v0, double.class),
bufferStore(0, double.class),
dup(),
vmLoad(v1, double.class),
bufferStore(8, double.class),
dup(),
vmLoad(v2, double.class),
bufferStore(16, double.class),
dup(),
vmLoad(v3, double.class),
bufferStore(24, double.class),
});
}
}

1 comment on commit c6de66c

@openjdk-notifier
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.