/
RuntimeHelper.java.template
240 lines (209 loc) · 9.37 KB
/
RuntimeHelper.java.template
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
// Generated by jextract
import java.lang.foreign.Linker;
import java.lang.foreign.FunctionDescriptor;
import java.lang.foreign.GroupLayout;
import java.lang.foreign.SymbolLookup;
import java.lang.foreign.MemoryLayout;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.Arena;
import java.lang.foreign.SegmentAllocator;
import java.lang.foreign.ValueLayout;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.io.File;
import java.nio.file.Path;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Optional;
import java.util.stream.Stream;
import static java.lang.foreign.Linker.*;
import static java.lang.foreign.ValueLayout.*;
/**
 * Runtime support for generated bindings: symbol lookup, downcall/upcall handle creation,
 * and a dynamic invoker for variadic native functions.
 *
 * <p>This is a template: the {@code #LOAD_LIBRARIES#} token in the static initializer is
 * presumably substituted with library-loading code by the generator before compilation.
 */
final class RuntimeHelper {

    private static final Linker LINKER = Linker.nativeLinker();
    private static final ClassLoader LOADER = RuntimeHelper.class.getClassLoader();
    private static final MethodHandles.Lookup MH_LOOKUP = MethodHandles.lookup();
    private static final SymbolLookup SYMBOL_LOOKUP;

    /** Allocator bound into handles that never need one; reaching it indicates a bug. */
    private static final SegmentAllocator THROWING_ALLOCATOR =
            (x, y) -> { throw new AssertionError("should not reach here"); };

    /** Allocator that serves every request from a fresh automatic (GC-managed) arena. */
    static final SegmentAllocator CONSTANT_ALLOCATOR =
            (size, align) -> Arena.ofAuto().allocate(size, align);

    static {
        #LOAD_LIBRARIES#
        // Resolve symbols from libraries loaded by this class's loader first, then fall back
        // to the native linker's default lookup.
        SymbolLookup loaderLookup = SymbolLookup.loaderLookup();
        SYMBOL_LOOKUP = name -> loaderLookup.find(name).or(() -> LINKER.defaultLookup().find(name));
    }

    // Suppresses default constructor, ensuring non-instantiability.
    private RuntimeHelper() {}

    /**
     * Returns {@code obj}, or throws if it is {@code null}.
     *
     * @param obj        the resolved value (e.g. a handle or segment returned by a lookup helper)
     * @param symbolName the native symbol name, used in the error message
     * @return {@code obj}
     * @throws UnsatisfiedLinkError if {@code obj} is {@code null}
     */
    static <T> T requireNonNull(T obj, String symbolName) {
        if (obj == null) {
            throw new UnsatisfiedLinkError("unresolved symbol: " + symbolName);
        }
        return obj;
    }

    /**
     * Looks up a global variable and sizes its segment to {@code layout}.
     *
     * @return the resized segment, or {@code null} if the symbol is not found
     */
    static MemorySegment lookupGlobalVariable(String name, MemoryLayout layout) {
        return SYMBOL_LOOKUP.find(name)
                .map(s -> s.reinterpret(layout.byteSize()))
                .orElse(null);
    }

    /**
     * Creates a downcall handle for the named function.
     *
     * @return the handle, or {@code null} if the symbol is not found
     */
    static MethodHandle downcallHandle(String name, FunctionDescriptor fdesc) {
        return SYMBOL_LOOKUP.find(name)
                .map(addr -> LINKER.downcallHandle(addr, fdesc))
                .orElse(null);
    }

    /** Creates a downcall handle whose target address is supplied as the first argument. */
    static MethodHandle downcallHandle(FunctionDescriptor fdesc) {
        return LINKER.downcallHandle(fdesc);
    }

    /**
     * Creates a handle for a variadic native function.
     *
     * <p>The handle takes the named arguments followed by an {@code Object[]} holding the
     * variadic arguments. It is re-linked on every call from the runtime types of those arguments.
     *
     * @return the handle, or {@code null} if the symbol is not found
     */
    static MethodHandle downcallHandleVariadic(String name, FunctionDescriptor fdesc) {
        return SYMBOL_LOOKUP.find(name)
                .map(addr -> VarargsInvoker.make(addr, fdesc))
                .orElse(null);
    }

    /**
     * Finds the virtual method {@code name} on interface {@code fi} whose type matches
     * {@code fdesc}.
     *
     * @throws AssertionError wrapping any failure from the lookup
     */
    static MethodHandle upcallHandle(Class<?> fi, String name, FunctionDescriptor fdesc) {
        try {
            return MH_LOOKUP.findVirtual(fi, name, fdesc.toMethodType());
        } catch (Throwable ex) {
            throw new AssertionError(ex);
        }
    }

    /**
     * Binds {@code fiHandle} to receiver {@code z} and creates an upcall stub whose lifetime
     * is tied to {@code scope}.
     *
     * @throws AssertionError wrapping any failure from binding or stub creation
     */
    static <Z> MemorySegment upcallStub(MethodHandle fiHandle, Z z, FunctionDescriptor fdesc, Arena scope) {
        try {
            fiHandle = fiHandle.bindTo(z);
            return LINKER.upcallStub(fiHandle, fdesc, scope);
        } catch (Throwable ex) {
            throw new AssertionError(ex);
        }
    }

    /**
     * Views {@code addr} as an array of {@code numElements} elements of {@code layout}.
     * The returned segment is associated with {@code arena}.
     */
    static MemorySegment asArray(MemorySegment addr, MemoryLayout layout, int numElements, Arena arena) {
        // reinterpret(long, Arena, Consumer) takes the Arena itself, not arena.scope().
        return addr.reinterpret(numElements * layout.byteSize(), arena, null);
    }

    // Internals only below this point

    /**
     * Dynamic invoker for variadic functions.
     *
     * <p>Each call derives layouts for the variadic arguments from their runtime classes and
     * links a fresh downcall handle for that exact argument list.
     */
    private static final class VarargsInvoker {
        private static final MethodHandle INVOKE_MH;
        private final MemorySegment symbol;
        private final FunctionDescriptor function;

        private VarargsInvoker(MemorySegment symbol, FunctionDescriptor function) {
            this.symbol = symbol;
            this.function = function;
        }

        static {
            try {
                INVOKE_MH = MethodHandles.lookup().findVirtual(VarargsInvoker.class, "invoke",
                        MethodType.methodType(Object.class, SegmentAllocator.class, Object[].class));
            } catch (ReflectiveOperationException e) {
                throw new RuntimeException(e);
            }
        }

        /**
         * Builds the public handle for a variadic function.
         *
         * <p>Its type is {@code ([SegmentAllocator,] namedCarriers..., Object[]) -> ret}. The
         * leading allocator is present only when the function returns a struct/union.
         */
        static MethodHandle make(MemorySegment symbol, FunctionDescriptor function) {
            VarargsInvoker invoker = new VarargsInvoker(symbol, function);
            // Collect the named args plus the trailing Object[] into invoke's Object[] parameter.
            MethodHandle handle = INVOKE_MH.bindTo(invoker)
                    .asCollector(Object[].class, function.argumentLayouts().size() + 1);
            MethodType mtype = MethodType.methodType(function.returnLayout().isPresent()
                    ? carrier(function.returnLayout().get(), true)
                    : void.class);
            for (MemoryLayout layout : function.argumentLayouts()) {
                mtype = mtype.appendParameterTypes(carrier(layout, false));
            }
            mtype = mtype.appendParameterTypes(Object[].class);
            if (returnsGroup(function)) {
                mtype = mtype.insertParameterTypes(0, SegmentAllocator.class);
            } else {
                // No allocator is exposed; bind a sentinel that must never be used.
                handle = MethodHandles.insertArguments(handle, 0, THROWING_ALLOCATOR);
            }
            return handle.asType(mtype);
        }

        /** Returns the Java carrier type for a value or group layout. */
        static Class<?> carrier(MemoryLayout layout, boolean ret) {
            if (layout instanceof ValueLayout valueLayout) {
                return valueLayout.carrier();
            } else if (layout instanceof GroupLayout) {
                return MemorySegment.class;
            } else {
                throw new AssertionError("Cannot get here!");
            }
        }

        /** True when the function returns a struct/union and so needs a SegmentAllocator. */
        private static boolean returnsGroup(FunctionDescriptor function) {
            return function.returnLayout().isPresent()
                    && function.returnLayout().get() instanceof GroupLayout;
        }

        /**
         * Target of {@link #INVOKE_MH}.
         *
         * @param allocator used for struct/union returns; otherwise the throwing sentinel
         * @param args      the named arguments followed by one trailing {@code Object[]} of
         *                  variadic arguments
         * @return the native function's result (boxed), or {@code null} for void functions
         */
        private Object invoke(SegmentAllocator allocator, Object[] args) throws Throwable {
            int nNamedArgs = function.argumentLayouts().size();
            assert args.length == nNamedArgs + 1;
            // The last argument is the array gathered by the vararg collector.
            Object[] unnamedArgs = (Object[]) args[args.length - 1];
            int argsCount = nNamedArgs + unnamedArgs.length;

            MemoryLayout[] argLayouts = new MemoryLayout[argsCount];
            for (int i = 0; i < nNamedArgs; i++) {
                argLayouts[i] = function.argumentLayouts().get(i);
            }
            for (int i = 0; i < unnamedArgs.length; i++) {
                argLayouts[nNamedArgs + i] = variadicLayout(normalize(unnamedArgs[i].getClass()));
            }

            FunctionDescriptor f = function.returnLayout().isEmpty()
                    ? FunctionDescriptor.ofVoid(argLayouts)
                    : FunctionDescriptor.of(function.returnLayout().get(), argLayouts);
            // The linker must know where the variadic arguments begin: some ABIs pass them
            // differently from named arguments.
            MethodHandle mh = LINKER.downcallHandle(symbol, f, Linker.Option.firstVariadicArg(nNamedArgs));
            if (returnsGroup(function)) {
                mh = mh.bindTo(allocator);
            }

            // Flatten the argument list so it can be passed to an asSpreader MH.
            Object[] allArgs = new Object[argsCount];
            System.arraycopy(args, 0, allArgs, 0, nNamedArgs);
            System.arraycopy(unnamedArgs, 0, allArgs, nNamedArgs, unnamedArgs.length);
            return mh.asSpreader(Object[].class, argsCount).invoke(allArgs);
        }

        /** Maps a wrapper class to its primitive type; other classes are returned unchanged. */
        private static Class<?> unboxIfNeeded(Class<?> clazz) {
            if (clazz == Boolean.class) {
                return boolean.class;
            } else if (clazz == Void.class) {
                return void.class;
            } else if (clazz == Byte.class) {
                return byte.class;
            } else if (clazz == Character.class) {
                return char.class;
            } else if (clazz == Short.class) {
                return short.class;
            } else if (clazz == Integer.class) {
                return int.class;
            } else if (clazz == Long.class) {
                return long.class;
            } else if (clazz == Float.class) {
                return float.class;
            } else if (clazz == Double.class) {
                return double.class;
            } else {
                return clazz;
            }
        }

        /**
         * Widens integral primitives (up to int) to long, and float to double, for variadic passing.
         * NOTE(review): C's default promotions keep int as int; widening to a 64-bit slot is
         * presumably relied on to be harmless for the supported ABIs — confirm.
         */
        private Class<?> promote(Class<?> c) {
            if (c == byte.class || c == char.class || c == short.class || c == int.class) {
                return long.class;
            } else if (c == float.class) {
                return double.class;
            } else {
                return c;
            }
        }

        /** Maps a variadic argument's runtime class to the type used to pick its layout. */
        private Class<?> normalize(Class<?> c) {
            c = unboxIfNeeded(c);
            if (c.isPrimitive()) {
                return promote(c);
            }
            // getClass() never returns an interface: segments arrive as implementation classes
            // of MemorySegment, so an identity check against MemorySegment.class cannot match.
            if (MemorySegment.class.isAssignableFrom(c)) {
                return MemorySegment.class;
            }
            throw new IllegalArgumentException("Invalid type for ABI: " + c.getTypeName());
        }

        /** Returns the layout used to pass a normalized variadic argument type. */
        private MemoryLayout variadicLayout(Class<?> c) {
            if (c == long.class) {
                return JAVA_LONG;
            } else if (c == double.class) {
                return JAVA_DOUBLE;
            } else if (c == MemorySegment.class) {
                return ADDRESS;
            } else {
                throw new IllegalArgumentException("Unhandled variadic argument class: " + c);
            }
        }
    }
}