Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

8261906: Improve jextract support for virtual functions #456

Closed
wants to merge 13 commits into from
Closed
Changes from 3 commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
@@ -66,24 +66,24 @@ public Constant addLayout(String javaName, MemoryLayout layout) {
() -> emitLayoutField(javaName, layout));
}

public Constant addFieldVarHandle(String javaName, String nativeName, MemoryLayout layout,
Class<?> type, String rootJavaName, List<String> prefixElementNames) {
return addVarHandle(javaName, nativeName, layout, type, rootJavaName, prefixElementNames);
public Constant addFieldVarHandle(String javaName, String nativeName, VarInfo varInfo,
String rootJavaName, List<String> prefixElementNames) {
return addVarHandle(javaName, nativeName, varInfo, rootJavaName, prefixElementNames);
}

public Constant addGlobalVarHandle(String javaName, String nativeName, MemoryLayout layout, Class<?> type) {
return addVarHandle(javaName, nativeName, layout, type, null, List.of());
public Constant addGlobalVarHandle(String javaName, String nativeName, VarInfo varInfo) {
return addVarHandle(javaName, nativeName, varInfo, null, List.of());
}

private Constant addVarHandle(String javaName, String nativeName, MemoryLayout layout, Class<?> type,
private Constant addVarHandle(String javaName, String nativeName, VarInfo varInfo,
String rootLayoutName, List<String> prefixElementNames) {
return emitIfAbsent(javaName, Constant.Kind.VAR_HANDLE,
() -> emitVarHandleField(javaName, nativeName, type, layout, rootLayoutName, prefixElementNames));
() -> emitVarHandleField(javaName, nativeName, varInfo, rootLayoutName, prefixElementNames));
}

public MethodHandleConstant addMethodHandle(String javaName, String nativeName, MethodType mtype, FunctionDescriptor desc, boolean virtual, boolean varargs) {
return (MethodHandleConstant)emitIfAbsent(javaName, Constant.Kind.METHOD_HANDLE,
() -> emitMethodHandleField(javaName, nativeName, mtype, desc, virtual, varargs));
public Constant addMethodHandle(String javaName, String nativeName, FunctionInfo functionInfo, boolean virtual) {
return emitIfAbsent(javaName, Constant.Kind.METHOD_HANDLE,
() -> emitMethodHandleField(javaName, nativeName, functionInfo, virtual));
}

public Constant addSegment(String javaName, String nativeName, MemoryLayout layout) {
@@ -169,54 +169,6 @@ Constant emitGetter(JavaSourceBuilder builder, String mods, Function<List<String
l -> l.get(2);
}

static class MethodHandleConstant extends Constant {

final MethodType mtype;
final boolean virtual;
final boolean varargs;


MethodHandleConstant(String className, String javaName, Kind kind, MethodType mtype, boolean virtual, boolean varargs) {
super(className, javaName, kind);
this.mtype = mtype;
this.virtual = virtual;
this.varargs = varargs;
}

@Override
MethodHandleConstant emitGetter(JavaSourceBuilder builder, String mods, Function<List<String>, String> getterNameFunc) {
return (MethodHandleConstant)super.emitGetter(builder, mods, getterNameFunc);
}

@Override
MethodHandleConstant emitGetter(JavaSourceBuilder builder, String mods, Function<List<String>, String> getterNameFunc, String symbolName) {
return (MethodHandleConstant)super.emitGetter(builder, mods, getterNameFunc, symbolName);
}

MethodHandleConstant emitFunction(JavaSourceBuilder builder, String mods, Function<List<String>, String> getterNameFunc,
List<String> paramNames) {
if (virtual) {
builder.emitVirtualFunctionWrapper(mods, mtype, getterNameFunc.apply(getterNameParts()), accessExpression());
} else {
builder.emitFunctionWrapper(mods, mtype, getterNameFunc.apply(getterNameParts()), accessExpression(),
varargs, paramNames);
}
return this;
}

MethodHandleConstant emitFunction(JavaSourceBuilder builder, String mods, Function<List<String>, String> getterNameFunc,
List<String> paramNames, String symbolName) {
if (virtual) {
builder.emitVirtualFunctionWrapper(mods, mtype, getterNameFunc.apply(getterNameParts()), accessExpression(),
true, symbolName);
} else {
builder.emitFunctionWrapper(mods, mtype, getterNameFunc.apply(getterNameParts()), accessExpression(),
varargs, paramNames, true, symbolName);
}
return this;
}
}

// private generators

public Constant emitIfAbsent(String name, Constant.Kind kind, Supplier<Constant> constantFactory) {
@@ -233,9 +185,8 @@ public Constant emitIfAbsent(String name, Constant.Kind kind, Supplier<Constant>
return constant;
}

private MethodHandleConstant emitMethodHandleField(String javaName, String nativeName, MethodType mtype,
FunctionDescriptor desc, boolean virtual, boolean varargs) {
Constant functionDesc = addFunctionDesc(javaName, desc);
private Constant emitMethodHandleField(String javaName, String nativeName, FunctionInfo functionInfo, boolean virtual) {
Constant functionDesc = addFunctionDesc(javaName, functionInfo.descriptor());
incrAlign();
String fieldName = Constant.Kind.METHOD_HANDLE.fieldName(javaName);
indent();
@@ -248,27 +199,27 @@ private MethodHandleConstant emitMethodHandleField(String javaName, String nativ
append(",\n");
indent();
}
append("\"" + mtype.toMethodDescriptorString() + "\",\n");
append("\"" + functionInfo.methodType().toMethodDescriptorString() + "\",\n");
indent();
append(functionDesc.accessExpression());
append(", ");
// isVariadic
append(varargs);
append(functionInfo.isVarargs());
append("\n");
decrAlign();
indent();
append(");\n");
decrAlign();
return new MethodHandleConstant(className(), javaName, Constant.Kind.METHOD_HANDLE, mtype, virtual, varargs);
return new Constant(className(), javaName, Constant.Kind.METHOD_HANDLE);
}

private Constant emitVarHandleField(String javaName, String nativeName, Class<?> type, MemoryLayout layout,
private Constant emitVarHandleField(String javaName, String nativeName, VarInfo varInfo,
String rootLayoutName, List<String> prefixElementNames) {
String layoutAccess = rootLayoutName != null ?
Constant.Kind.LAYOUT.fieldName(rootLayoutName) :
addLayout(javaName, layout).accessExpression();
addLayout(javaName, varInfo.layout()).accessExpression();
incrAlign();
String typeName = type.getName();
String typeName = varInfo.carrier().getName();
boolean isAddr = typeName.contains("MemoryAddress");
if (isAddr) {
typeName = "long";
@@ -24,6 +24,7 @@
*/
package jdk.internal.jextract.impl;

import jdk.incubator.foreign.Addressable;
import jdk.incubator.foreign.FunctionDescriptor;
import jdk.incubator.foreign.MemoryAddress;
import jdk.incubator.foreign.MemoryLayout;
@@ -34,6 +35,7 @@

import java.lang.constant.ClassDesc;
import java.lang.invoke.MethodType;
import java.util.ArrayList;
import java.util.List;

/**
@@ -58,41 +60,38 @@ String superClass() {
}

@Override
public void addVar(String javaName, String nativeName, MemoryLayout layout, Class<?> type) {
if (type.equals(MemorySegment.class)) {
public void addVar(String javaName, String nativeName, VarInfo varInfo) {
if (varInfo.carrier().equals(MemorySegment.class)) {
emitWithConstantClass(javaName, constantBuilder -> {
constantBuilder.addSegment(javaName, nativeName, layout)
constantBuilder.addSegment(javaName, nativeName, varInfo.layout())
.emitGetter(this, MEMBER_MODS, Constant.QUALIFIED_NAME, nativeName);
});
} else {
emitWithConstantClass(javaName, constantBuilder -> {
constantBuilder.addLayout(javaName, layout)
constantBuilder.addLayout(javaName, varInfo.layout())
.emitGetter(this, MEMBER_MODS, Constant.QUALIFIED_NAME);
Constant vhConstant = constantBuilder.addGlobalVarHandle(javaName, nativeName, layout, type)
Constant vhConstant = constantBuilder.addGlobalVarHandle(javaName, nativeName, varInfo)
.emitGetter(this, MEMBER_MODS, Constant.QUALIFIED_NAME);
Constant segmentConstant = constantBuilder.addSegment(javaName, nativeName, layout)
Constant segmentConstant = constantBuilder.addSegment(javaName, nativeName, varInfo.layout())
.emitGetter(this, MEMBER_MODS, Constant.QUALIFIED_NAME, nativeName);
emitGlobalGetter(segmentConstant, vhConstant, javaName, nativeName, type);
emitGlobalSetter(segmentConstant, vhConstant, javaName, nativeName, type);
emitGlobalGetter(segmentConstant, vhConstant, javaName, nativeName, varInfo.carrier());
emitGlobalSetter(segmentConstant, vhConstant, javaName, nativeName, varInfo.carrier());
if (varInfo.functionInfo().isPresent()) {
FunctionInfo functionInfo = varInfo.functionInfo().get();
Constant mhConstant = constantBuilder.addMethodHandle(javaName, nativeName, functionInfo, true)
.emitGetter(this, MEMBER_MODS, Constant.QUALIFIED_NAME, nativeName);
emitVirtualFunctionWrapper(mhConstant, javaName, functionInfo.methodType());
}
});
}
}

@Override
public void addFunction(String javaName, String nativeName, MethodType mtype, FunctionDescriptor desc, boolean varargs, List<String> paramNames) {
public void addFunction(String javaName, String nativeName, FunctionInfo functionInfo) {
emitWithConstantClass(javaName, constantBuilder -> {
constantBuilder.addMethodHandle(javaName, nativeName, mtype, desc, false, varargs)
.emitGetter(this, MEMBER_MODS, Constant.QUALIFIED_NAME, nativeName)
.emitFunction(this, MEMBER_MODS, Constant.JAVA_NAME, paramNames, nativeName);
});
}

@Override
public void addVirtualFunction(String javaName, String nativeName, MethodType mtype, FunctionDescriptor desc) {
emitWithConstantClass(javaName, constantBuilder -> {
constantBuilder.addMethodHandle(javaName, nativeName, mtype, desc, true, false)
.emitGetter(this, MEMBER_MODS, Constant.QUALIFIED_NAME)
.emitFunction(this, MEMBER_MODS, Constant.JAVA_NAME, null);
Constant mhConstant = constantBuilder.addMethodHandle(javaName, nativeName, functionInfo, false)
.emitGetter(this, MEMBER_MODS, Constant.QUALIFIED_NAME, nativeName);
emitFunctionWrapper(mhConstant, javaName, nativeName, functionInfo);
});
}

@@ -122,6 +121,125 @@ public void addTypedef(String name, String superClass, Type type) {

// private generation

private void emitFunctionWrapper(Constant mhConstant, String javaName, String nativeName, FunctionInfo functionInfo) {
incrAlign();
indent();
append(MEMBER_MODS + " ");
append(functionInfo.methodType().returnType().getSimpleName() + " " + javaName + " (");
String delim = "";
List<String> pExprs = new ArrayList<>();
List<String> paramNames = functionInfo.parameterNames().get();
final int numParams = paramNames.size();
for (int i = 0 ; i < numParams; i++) {
String pName = paramNames.get(i);
if (pName.isEmpty()) {
pName = "x" + i;
}
if (functionInfo.methodType().parameterType(i).equals(MemoryAddress.class)) {
pExprs.add(pName + ".address()");
} else {
pExprs.add(pName);
}
Class<?> pType = functionInfo.methodType().parameterType(i);
if (pType.equals(MemoryAddress.class)) {
pType = Addressable.class;
}
append(delim + " " + pType.getSimpleName() + " " + pName);
delim = ", ";
}
if (functionInfo.isVarargs()) {
String lastArg = "x" + numParams;
if (numParams > 0) {
append(", ");
}
append("Object... " + lastArg);
pExprs.add(lastArg);
}
append(") {\n");
incrAlign();
indent();
append("var mh$ = RuntimeHelper.requireNonNull(");
append(mhConstant.accessExpression());
append(", \"");
append(nativeName);
append("\");\n");
indent();
append("try {\n");
incrAlign();
indent();
if (!functionInfo.methodType().returnType().equals(void.class)) {
append("return (" + functionInfo.methodType().returnType().getName() + ")");
}
append("mh$.invokeExact(" + String.join(", ", pExprs) + ");\n");
decrAlign();
indent();
append("} catch (Throwable ex$) {\n");
incrAlign();
indent();
append("throw new AssertionError(\"should not reach here\", ex$);\n");
decrAlign();
indent();
append("}\n");
decrAlign();
indent();
append("}\n");
decrAlign();
}

private void emitVirtualFunctionWrapper(Constant mhConstant, String javaName, MethodType mtype) {
incrAlign();
indent();
append(MEMBER_MODS + " ");
append(mtype.returnType().getSimpleName() + " " + javaName + " (");
String delim = "";
List<String> pExprs = new ArrayList<>();
int numParams = mtype.parameterCount();
for (int i = 0 ; i < numParams; i++) {
String pName = "x" + i;
if (mtype.parameterType(i).equals(MemoryAddress.class)) {
pExprs.add(pName + ".address()");
} else {
pExprs.add(pName);
}
Class<?> pType = mtype.parameterType(i);
if (pType.equals(MemoryAddress.class)) {
pType = Addressable.class;
}
append(delim + " " + pType.getSimpleName() + " " + pName);
delim = ", ";
}
append(") {\n");
incrAlign();
indent();
append("var mh$ = ");
append(mhConstant.accessExpression());
append(";\n");
indent();
append("try {\n");
incrAlign();
indent();
if (!mtype.returnType().equals(void.class)) {
append("return (" + mtype.returnType().getName() + ")");
}
append("mh$.invokeExact(");
append("(Addressable)");
append(javaName + "$get(), ");
append(String.join(", ", pExprs) + ");\n");
decrAlign();
indent();
append("} catch (Throwable ex$) {\n");
incrAlign();
indent();
append("throw new AssertionError(\"should not reach here\", ex$);\n");
decrAlign();
indent();
append("}\n");
decrAlign();
indent();
append("}\n");
decrAlign();
}

private void emitPrimitiveTypedef(Type.Primitive primType, String name) {
Type.Primitive.Kind kind = primType.kind();
if (primitiveKindSupported(kind) && !kind.layout().isEmpty()) {
Loading