Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

8244720: Check MethodType and FunctionDescritpor used when linking #158

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -130,15 +130,7 @@ public long offset() {
}

public VarHandle dereferenceHandle(Class<?> carrier) {
if (!(layout instanceof ValueLayout)) {
throw badLayoutPath("layout path does not select a value layout");
}

if (!carrier.isPrimitive() || carrier == void.class || carrier == boolean.class // illegal carrier?
|| Wrapper.forPrimitiveType(carrier).bitWidth() != layout.bitSize()) { // carrier has the right size?
throw new IllegalArgumentException("Invalid carrier: " + carrier + ", for layout " + layout);
}

Utils.checkPrimitiveCarrierCompat(carrier, layout);
checkAlignment(this);

return Utils.fixUpVarHandle(JLI.memoryAccessVarHandle(
Expand Down
Expand Up @@ -26,13 +26,16 @@

package jdk.internal.foreign;

import jdk.incubator.foreign.GroupLayout;
import jdk.incubator.foreign.MemoryAddress;
import jdk.incubator.foreign.MemoryHandles;
import jdk.incubator.foreign.MemoryLayout;
import jdk.incubator.foreign.SystemABI;
import jdk.incubator.foreign.ValueLayout;
import jdk.internal.access.foreign.MemoryAddressProxy;
import jdk.internal.foreign.abi.SharedUtils;
import jdk.internal.misc.VM;
import sun.invoke.util.Wrapper;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
Expand Down Expand Up @@ -112,4 +115,27 @@ public static <Z extends MemoryLayout> Z pick(Z sysv, Z win64, Z aarch64) {
default -> throw new ExceptionInInitializerError("Unexpected ABI: " + abi.name());
};
}

public static void checkPrimitiveCarrierCompat(Class<?> carrier, MemoryLayout layout) {
checkLayoutType(layout, ValueLayout.class);
if (!isValidPrimitiveCarrier(carrier))
throw new IllegalArgumentException("Unsupported carrier: " + carrier);
if (Wrapper.forPrimitiveType(carrier).bitWidth() != layout.bitSize())
throw new IllegalArgumentException("Carrier size mismatch: " + carrier + " != " + layout);
}

public static boolean isValidPrimitiveCarrier(Class<?> carrier) {
return carrier == byte.class
|| carrier == short.class
|| carrier == char.class
|| carrier == int.class
|| carrier == long.class
|| carrier == float.class
|| carrier == double.class;
}

public static void checkLayoutType(MemoryLayout layout, Class<? extends MemoryLayout> layoutType) {
if (!layoutType.isInstance(layout))
throw new IllegalArgumentException("Expected a " + layoutType.getSimpleName() + ": " + layout);
}
}
Expand Up @@ -25,23 +25,23 @@
package jdk.internal.foreign.abi;

import jdk.incubator.foreign.FunctionDescriptor;
import jdk.incubator.foreign.GroupLayout;
import jdk.incubator.foreign.MemoryAddress;
import jdk.incubator.foreign.MemoryLayout;
import jdk.incubator.foreign.MemorySegment;
import jdk.incubator.foreign.SequenceLayout;
import jdk.incubator.foreign.SystemABI;
import jdk.incubator.foreign.ValueLayout;
import jdk.internal.foreign.MemoryAddressImpl;
import jdk.internal.foreign.Utils;

import jdk.incubator.foreign.GroupLayout;
import jdk.incubator.foreign.MemoryLayout;
import jdk.incubator.foreign.SequenceLayout;
import jdk.incubator.foreign.ValueLayout;
import jdk.internal.foreign.abi.aarch64.AArch64ABI;
import jdk.internal.foreign.abi.x64.sysv.SysVx64ABI;
import jdk.internal.foreign.abi.x64.windows.Windowsx64ABI;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.util.List;
import java.util.stream.IntStream;

import static java.lang.invoke.MethodHandles.collectArguments;
Expand Down Expand Up @@ -186,11 +186,32 @@ private static MemoryAddress bufferCopy(MemoryAddress dest, MemorySegment buffer
return dest;
}

public static void checkFunctionTypes(MethodType mt, FunctionDescriptor cDesc) {
if (mt.parameterCount() != cDesc.argumentLayouts().size())
throw new IllegalArgumentException("arity must match!");
if ((mt.returnType() == void.class) == cDesc.returnLayout().isPresent())
throw new IllegalArgumentException("return type presence must match!");
private static void checkCompatibleType(Class<?> carrier, MemoryLayout layout, long addressSize) {
JornVernee marked this conversation as resolved.
Show resolved Hide resolved
if (carrier.isPrimitive()) {
Utils.checkPrimitiveCarrierCompat(carrier, layout);
} else if (carrier == MemoryAddress.class) {
Utils.checkLayoutType(layout, ValueLayout.class);
if (layout.bitSize() != addressSize)
throw new IllegalArgumentException("Address size mismatch: " + addressSize + " != " + layout.bitSize());
} else if(carrier == MemorySegment.class) {
Utils.checkLayoutType(layout, GroupLayout.class);
} else {
throw new IllegalArgumentException("Unsupported carrier: " + carrier);
}
}

public static void checkFunctionTypes(MethodType mt, FunctionDescriptor cDesc, long addressSize) {
if (mt.returnType() == void.class != cDesc.returnLayout().isEmpty())
throw new IllegalArgumentException("Return type mismatch: " + mt + " != " + cDesc);
List<MemoryLayout> argLayouts = cDesc.argumentLayouts();
if (mt.parameterCount() != argLayouts.size())
throw new IllegalArgumentException("Arity mismatch: " + mt + " != " + cDesc);

int paramCount = mt.parameterCount();
for (int i = 0; i < paramCount; i++) {
checkCompatibleType(mt.parameterType(i), argLayouts.get(i), addressSize);
}
cDesc.returnLayout().ifPresent(rl -> checkCompatibleType(mt.returnType(), rl, addressSize));
}

public static Class<?> primitiveCarrierForSize(long size) {
Expand Down
Expand Up @@ -42,6 +42,8 @@
public class AArch64ABI implements SystemABI {
private static AArch64ABI instance;

static final long ADDRESS_SIZE = 64; // bits

public static AArch64ABI getInstance() {
if (instance == null) {
instance = new AArch64ABI();
Expand Down
Expand Up @@ -98,7 +98,7 @@ public static class Bindings {
}

public static Bindings getBindings(MethodType mt, FunctionDescriptor cDesc, boolean forUpcall) {
SharedUtils.checkFunctionTypes(mt, cDesc);
SharedUtils.checkFunctionTypes(mt, cDesc, AArch64ABI.ADDRESS_SIZE);

CallingSequenceBuilder csb = new CallingSequenceBuilder(forUpcall);

Expand Down
Expand Up @@ -93,7 +93,7 @@ public static class Bindings {
}

public static Bindings getBindings(MethodType mt, FunctionDescriptor cDesc, boolean forUpcall) {
SharedUtils.checkFunctionTypes(mt, cDesc);
SharedUtils.checkFunctionTypes(mt, cDesc, SysVx64ABI.ADDRESS_SIZE);

CallingSequenceBuilder csb = new CallingSequenceBuilder(forUpcall);

Expand Down
Expand Up @@ -47,6 +47,8 @@ public class SysVx64ABI implements SystemABI {

private static SysVx64ABI instance;

static final long ADDRESS_SIZE = 64; // bits

public static SysVx64ABI getInstance() {
if (instance == null) {
instance = new SysVx64ABI();
Expand Down
Expand Up @@ -83,7 +83,7 @@ public static class Bindings {
}

public static Bindings getBindings(MethodType mt, FunctionDescriptor cDesc, boolean forUpcall) {
SharedUtils.checkFunctionTypes(mt, cDesc);
SharedUtils.checkFunctionTypes(mt, cDesc, Windowsx64ABI.ADDRESS_SIZE);

class CallingSequenceBuilderHelper {
final CallingSequenceBuilder csb = new CallingSequenceBuilder(forUpcall);
Expand Down
Expand Up @@ -49,6 +49,8 @@ public class Windowsx64ABI implements SystemABI {

private static Windowsx64ABI instance;

static final long ADDRESS_SIZE = 64; // bits

public static Windowsx64ABI getInstance() {
if (instance == null) {
instance = new Windowsx64ABI();
Expand Down
4 changes: 2 additions & 2 deletions test/jdk/java/foreign/StdLibTest.java
Expand Up @@ -200,7 +200,7 @@ static class StdLibHelper {

qsort = abi.downcallHandle(lookup.lookup("qsort"),
MethodType.methodType(void.class, MemoryAddress.class, long.class, long.class, MemoryAddress.class),
FunctionDescriptor.ofVoid(C_POINTER, C_LONG, C_LONG, C_POINTER));
FunctionDescriptor.ofVoid(C_POINTER, C_LONGLONG, C_LONGLONG, C_POINTER));

//qsort upcall handle
qsortCompar = MethodHandles.lookup().findStatic(StdLibTest.StdLibHelper.class, "qsortCompare",
Expand Down Expand Up @@ -404,7 +404,7 @@ public static Object[][] printfArgs() {
enum PrintfArg {
INTEGRAL(int.class, asVarArg(C_INT), "%d", 42, 42),
STRING(MemoryAddress.class, asVarArg(C_POINTER), "%s", toCString("str").baseAddress(), "str"),
CHAR(char.class, asVarArg(C_CHAR), "%c", 'h', 'h'),
CHAR(byte.class, asVarArg(C_CHAR), "%c", (byte) 'h', 'h'),
DOUBLE(double.class, asVarArg(C_DOUBLE), "%.4f", 1.2345d, 1.2345d);

final Class<?> carrier;
Expand Down
112 changes: 112 additions & 0 deletions test/jdk/java/foreign/TestIllegalLink.java
@@ -0,0 +1,112 @@
/*
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*
*/

/*
* @test
*
* @run testng/othervm -Dforeign.restricted=permit TestIllegalLink
*/

import jdk.incubator.foreign.FunctionDescriptor;
import jdk.incubator.foreign.MemoryAddress;
import jdk.incubator.foreign.MemoryLayout;
import jdk.incubator.foreign.MemoryLayouts;
import jdk.incubator.foreign.MemorySegment;
import jdk.incubator.foreign.SystemABI;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

import java.lang.invoke.MethodType;

import static jdk.incubator.foreign.SystemABI.C_INT;
import static org.testng.Assert.assertTrue;
import static org.testng.Assert.fail;

public class TestIllegalLink {

private static final MemoryAddress dummyTarget = MemoryAddress.NULL;
private static final SystemABI ABI = SystemABI.getSystemABI();

@Test(dataProvider = "types")
public void testTypeMismatch(MethodType mt, FunctionDescriptor desc, String expectedExceptionMessage) {
try {
ABI.downcallHandle(dummyTarget, mt, desc);
fail("Expected IllegalArgumentException was not thrown");
} catch (IllegalArgumentException e) {
assertTrue(e.getMessage().contains(expectedExceptionMessage));
}
}

@DataProvider
public static Object[][] types() {
return new Object[][]{
{
MethodType.methodType(void.class),
FunctionDescriptor.of(C_INT),
"Return type mismatch"
},
{
MethodType.methodType(void.class),
FunctionDescriptor.ofVoid(C_INT),
"Arity mismatch"
},
{
MethodType.methodType(void.class, int.class),
FunctionDescriptor.ofVoid(MemoryLayout.ofPaddingBits(32)),
"Expected a ValueLayout"
},
{
MethodType.methodType(void.class, boolean.class),
FunctionDescriptor.ofVoid(MemoryLayouts.BITS_8_LE),
"Unsupported carrier"
},
{
MethodType.methodType(void.class, int.class),
FunctionDescriptor.ofVoid(MemoryLayouts.BITS_64_LE),
"Carrier size mismatch"
},
{
MethodType.methodType(void.class, MemoryAddress.class),
FunctionDescriptor.ofVoid(MemoryLayout.ofPaddingBits(64)),
"Expected a ValueLayout"
},
{
MethodType.methodType(void.class, MemoryAddress.class),
FunctionDescriptor.ofVoid(MemoryLayouts.BITS_16_LE),
"Address size mismatch"
},
{
MethodType.methodType(void.class, MemorySegment.class),
FunctionDescriptor.ofVoid(MemoryLayouts.BITS_64_LE),
"Expected a GroupLayout"
},
{
MethodType.methodType(void.class, String.class),
FunctionDescriptor.ofVoid(MemoryLayouts.BITS_64_LE),
"Unsupported carrier"
},
};
}

}