diff --git a/wasm/CHANGELOG.md b/wasm/CHANGELOG.md index ae0151d1fb8e..1a38a9544fb3 100644 --- a/wasm/CHANGELOG.md +++ b/wasm/CHANGELOG.md @@ -5,6 +5,7 @@ This changelog summarizes major changes to the WebAssembly engine implemented in ## Version 25.1.0 * Implemented the [exception handling](https://github.com/WebAssembly/exception-handling) proposal. This feature can be enabled with the experimental option `wasm.Exceptions=true`. +* Implemented the [typed function references](https://github.com/WebAssembly/function-references) proposal. This feature can be enabled with the experimental option `wasm.TypedFunctionReferences=true`. ## Version 25.0.0 diff --git a/wasm/src/org.graalvm.wasm.benchmark/src/org/graalvm/wasm/benchmark/MemoryFootprintBenchmarkRunner.java b/wasm/src/org.graalvm.wasm.benchmark/src/org/graalvm/wasm/benchmark/MemoryFootprintBenchmarkRunner.java index 1e22bf48220d..d175735884f0 100644 --- a/wasm/src/org.graalvm.wasm.benchmark/src/org/graalvm/wasm/benchmark/MemoryFootprintBenchmarkRunner.java +++ b/wasm/src/org.graalvm.wasm.benchmark/src/org/graalvm/wasm/benchmark/MemoryFootprintBenchmarkRunner.java @@ -169,7 +169,7 @@ static double getHeapSize() { static void sleep() { try { - Thread.sleep(2000); + Thread.sleep(5000); } catch (InterruptedException e) { Thread.currentThread().interrupt(); } diff --git a/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/AbstractBinarySuite.java b/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/AbstractBinarySuite.java index bc065856b093..358b3f203c3a 100644 --- a/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/AbstractBinarySuite.java +++ b/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/AbstractBinarySuite.java @@ -53,9 +53,10 @@ import org.graalvm.polyglot.io.ByteSequence; import org.graalvm.wasm.WasmLanguage; import org.graalvm.wasm.collection.ByteArrayList; +import org.graalvm.wasm.collection.IntArrayList; public abstract class AbstractBinarySuite { - protected static final byte[] EMPTY_BYTES = {}; + protected static final int[] EMPTY_INTS = {}; protected static void runRuntimeTest(byte[] binary, Consumer options, Consumer testCase) throws IOException { final Context.Builder contextBuilder = Context.newBuilder(WasmLanguage.ID); @@ -100,10 +101,10 @@ private static byte getByte(String hexString) { private static final class BinaryTypes { - private final List paramEntries = new ArrayList<>(); - private final List resultEntries = new ArrayList<>(); + private final List paramEntries = new ArrayList<>(); + private final List resultEntries = new ArrayList<>(); - private void add(byte[] params, byte[] results) { + private void add(int[] params, int[] results) { paramEntries.add(params); resultEntries.add(results); } @@ -112,18 +113,18 @@ private byte[] generateTypeSection() { ByteArrayList b = new ByteArrayList(); b.add(getByte("01")); b.add((byte) 0); // length is patched in the end - b.add((byte) paramEntries.size()); + b.addUnsignedInt32(paramEntries.size()); for (int i = 0; i < paramEntries.size(); i++) { b.add(getByte("60")); - byte[] params = paramEntries.get(i); - byte[] results = resultEntries.get(i); - b.add((byte) params.length); - for (byte param : params) { - b.add(param); + int[] params = paramEntries.get(i); + int[] results = resultEntries.get(i); + b.addUnsignedInt32(params.length); + for (int param : params) { + b.addSignedInt32(param); } - b.add((byte) results.length); - for (byte result : results) { - b.add(result); + b.addUnsignedInt32(results.length); + for (int result : results) { + b.addSignedInt32(result); } } b.set(1, (byte) (b.size() - 2)); @@ -132,9 +133,9 @@ private byte[] generateTypeSection() { } private static final class BinaryTables { - private final ByteArrayList tables = new ByteArrayList(); + private final IntArrayList tables = new IntArrayList(); - private void add(byte initSize, byte maxSize, byte elemType) { + private void add(int initSize, int maxSize, int elemType) { tables.add(initSize); tables.add(maxSize); tables.add(elemType); @@ -147,10 +148,10 @@ private byte[] generateTableSection() { final int tableCount = tables.size() / 3; b.add((byte) tableCount); for (int i = 0; i < tables.size(); i += 3) { - b.add(tables.get(i + 2)); + b.addSignedInt32(tables.get(i + 2)); b.add(getByte("01")); - b.add(tables.get(i)); - b.add(tables.get(i + 1)); + b.addUnsignedInt32(tables.get(i)); + b.addUnsignedInt32(tables.get(i + 1)); } b.set(1, (byte) (b.size() - 2)); return b.toArray(); @@ -158,9 +159,9 @@ private byte[] generateTableSection() { } private static final class BinaryMemories { - private final ByteArrayList memories = new ByteArrayList(); + private final IntArrayList memories = new IntArrayList(); - private void add(byte initSize, byte maxSize) { + private void add(int initSize, int maxSize) { memories.add(initSize); memories.add(maxSize); } @@ -173,8 +174,8 @@ private byte[] generateMemorySection() { b.add((byte) memoryCount); for (int i = 0; i < memories.size(); i += 2) { b.add(getByte("01")); - b.add(memories.get(i)); - b.add(memories.get(i + 1)); + b.addUnsignedInt32(memories.get(i)); + b.addUnsignedInt32(memories.get(i + 1)); } b.set(1, (byte) (b.size() - 2)); return b.toArray(); @@ -182,11 +183,11 @@ private byte[] generateMemorySection() { } private static final class BinaryFunctions { - private final ByteArrayList types = new ByteArrayList(); - private final List localEntries = new ArrayList<>(); + private final IntArrayList types = new IntArrayList(); + private final List localEntries = new ArrayList<>(); private final List codeEntries = new ArrayList<>(); - private void add(byte typeIndex, byte[] locals, byte[] code) { + private void add(int typeIndex, int[] locals, byte[] code) { types.add(typeIndex); localEntries.add(locals); codeEntries.add(code); @@ -197,9 +198,9 @@ private byte[] generateFunctionSection() { b.add(getByte("03")); b.add((byte) 0); // length is patched at the end final int functionCount = types.size(); - b.add((byte) functionCount); + b.addUnsignedInt32(functionCount); for (int i = 0; i < functionCount; i++) { - b.add(types.get(i)); + b.addUnsignedInt32(types.get(i)); } b.set(1, (byte) (b.size() - 2)); return b.toArray(); @@ -210,15 +211,15 @@ private byte[] generateCodeSection() { b.add(getByte("0A")); b.add((byte) 0); // length is patched at the end final int functionCount = types.size(); - b.add((byte) functionCount); + b.addUnsignedInt32(functionCount); for (int i = 0; i < functionCount; i++) { - byte[] locals = localEntries.get(i); + int[] locals = localEntries.get(i); byte[] code = codeEntries.get(i); int length = 1 + locals.length + code.length; - b.add((byte) length); - b.add((byte) locals.length); - for (byte l : locals) { - b.add(l); + b.addUnsignedInt32(length); + b.addUnsignedInt32(locals.length); + for (int l : locals) { + b.addSignedInt32(l); } for (byte op : code) { b.add(op); @@ -231,10 +232,10 @@ private byte[] generateCodeSection() { private static final class BinaryExports { private final ByteArrayList types = new ByteArrayList(); - private final ByteArrayList indices = new ByteArrayList(); + private final IntArrayList indices = new IntArrayList(); private final List names = new ArrayList<>(); - private void addFunctionExport(byte functionIndex, String name) { + private void addFunctionExport(int functionIndex, String name) { types.add(getByte("00")); indices.add(functionIndex); names.add(name.getBytes(StandardCharsets.UTF_8)); @@ -244,15 +245,15 @@ private byte[] generateExportSection() { ByteArrayList b = new ByteArrayList(); b.add(getByte("07")); b.add((byte) 0); // length is patched at the end - b.add((byte) types.size()); + b.addUnsignedInt32(types.size()); for (int i = 0; i < types.size(); i++) { final byte[] name = names.get(i); - b.add((byte) name.length); + b.addUnsignedInt32(name.length); for (byte value : name) { b.add(value); } b.add(types.get(i)); - b.add(indices.get(i)); + b.addUnsignedInt32(indices.get(i)); } b.set(1, (byte) (b.size() - 2)); return b.toArray(); @@ -313,10 +314,10 @@ private byte[] generateDataSection() { private static final class BinaryGlobals { private final ByteArrayList mutabilities = new ByteArrayList(); - private final ByteArrayList valueTypes = new ByteArrayList(); + private final IntArrayList valueTypes = new IntArrayList(); private final List expressions = new ArrayList<>(); - private void add(byte mutability, byte valueType, byte[] expression) { + private void add(byte mutability, int valueType, byte[] expression) { mutabilities.add(mutability); valueTypes.add(valueType); expressions.add(expression); @@ -328,7 +329,7 @@ private byte[] generateGlobalSection() { b.add((byte) 0); // length is patched at the end b.add((byte) mutabilities.size()); for (int i = 0; i < mutabilities.size(); i++) { - b.add(valueTypes.get(i)); + b.addSignedInt32(valueTypes.get(i)); b.add(mutabilities.get(i)); for (byte e : expressions.get(i)) { b.add(e); @@ -378,8 +379,8 @@ private byte[] generateCustomSections() { final byte[] name = names.get(i); final byte[] section = sections.get(i); final int size = 1 + name.length + section.length; - b.add((byte) size); // length is patched at the end - b.add((byte) name.length); + b.addUnsignedInt32(size); + b.addUnsignedInt32(name.length); b.addRange(name, 0, name.length); b.addRange(section, 0, section.length); } @@ -400,27 +401,27 @@ protected static class BinaryBuilder { private final BinaryCustomSections binaryCustomSections = new BinaryCustomSections(); - public BinaryBuilder addType(byte[] params, byte[] results) { + public BinaryBuilder addType(int[] params, int[] results) { binaryTypes.add(params, results); return this; } - public BinaryBuilder addTable(byte initSize, byte maxSize, byte elemType) { + public BinaryBuilder addTable(int initSize, int maxSize, int elemType) { binaryTables.add(initSize, maxSize, elemType); return this; } - public BinaryBuilder addMemory(byte initSize, byte maxSize) { + public BinaryBuilder addMemory(int initSize, int maxSize) { binaryMemories.add(initSize, maxSize); return this; } - public BinaryBuilder addFunction(byte typeIndex, byte[] locals, String hexCode) { + public BinaryBuilder addFunction(int typeIndex, int[] locals, String hexCode) { binaryFunctions.add(typeIndex, locals, WasmTestUtils.hexStringToByteArray(hexCode)); return this; } - public BinaryBuilder addFunctionExport(byte functionIndex, String name) { + public BinaryBuilder addFunctionExport(int functionIndex, String name) { binaryExports.addFunctionExport(functionIndex, name); return this; } @@ -435,7 +436,7 @@ public BinaryBuilder addData(String hexCode) { return this; } - public BinaryBuilder addGlobal(byte mutability, byte valueType, String hexCode) { + public BinaryBuilder addGlobal(byte mutability, int valueType, String hexCode) { binaryGlobals.add(mutability, valueType, WasmTestUtils.hexStringToByteArray(hexCode)); return this; } diff --git a/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/WasmJsApiSuite.java b/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/WasmJsApiSuite.java index 3d8e933eccd7..6959417dfda9 100644 --- a/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/WasmJsApiSuite.java +++ b/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/WasmJsApiSuite.java @@ -102,7 +102,7 @@ public class WasmJsApiSuite { private static final String REF_TYPES_OPTION = "wasm.BulkMemoryAndRefTypes"; - private static WasmFunctionInstance createWasmFunctionInstance(WasmContext context, byte[] paramTypes, byte[] resultTypes, RootNode functionRootNode) { + private static WasmFunctionInstance createWasmFunctionInstance(WasmContext context, int[] paramTypes, int[] resultTypes, RootNode functionRootNode) { WasmModule module = WasmModule.createBuiltin("dummyModule"); module.allocateFunctionType(paramTypes, resultTypes, context.getContextOptions().supportMultiValue()); WasmFunction func = module.declareFunction(0); @@ -111,7 +111,6 @@ private static WasmFunctionInstance createWasmFunctionInstance(WasmContext conte // Perform normal linking steps, incl. assignTypeEquivalenceClasses(). // Functions need to have type equivalence classes assigned for indirect calls. moduleInstance.store().linker().tryLink(moduleInstance); - assert func.typeEquivalenceClass() >= 0 : "type equivalence class must be assigned"; return new WasmFunctionInstance(moduleInstance, func, functionRootNode.getCallTarget()); } @@ -503,7 +502,7 @@ public void testGlobalWriteNull() throws IOException { public void testGlobalWriteAnyfuncRefTypesDisabled() throws IOException { runTest(WasmJsApiSuite::disableRefTypes, context -> { final WebAssembly wasm = new WebAssembly(context); - final WasmGlobal global = new WasmGlobal(ValueType.anyfunc, true, WasmConstant.NULL); + final WasmGlobal global = WasmGlobal.allocRef(ValueType.anyfunc, true, WasmConstant.NULL); try { wasm.globalWrite(global, WasmConstant.NULL); Assert.fail("Should have failed - ref types not enabled"); @@ -517,7 +516,7 @@ public void testGlobalWriteAnyfuncRefTypesDisabled() throws IOException { public void testGlobalWriteExternrefRefTypesDisabled() throws IOException { runTest(WasmJsApiSuite::disableRefTypes, context -> { final WebAssembly wasm = new WebAssembly(context); - final WasmGlobal global = new WasmGlobal(ValueType.externref, true, WasmConstant.NULL); + final WasmGlobal global = WasmGlobal.allocRef(ValueType.externref, true, WasmConstant.NULL); try { wasm.globalWrite(global, WasmConstant.NULL); Assert.fail("Should have failed - ref types not enabled"); @@ -1412,7 +1411,14 @@ public void testFuncTypeMultiValue() throws IOException, InterruptedException { @Test public void testMultiValueReferencePassThrough() throws IOException, InterruptedException { - final byte[] source = compileWat("data", """ + final byte[] source1 = compileWat("data", """ + (module + (type (func (result i32))) + (func (export "func") (type 0) + i32.const 42 + )) + """); + final byte[] source2 = compileWat("data", """ (module (type (func (result funcref externref))) (import "m" "f" (func (type 0))) @@ -1422,7 +1428,8 @@ public void testMultiValueReferencePassThrough() throws IOException, Interrupted """); runTest(context -> { final WebAssembly wasm = new WebAssembly(context); - final var func = new Executable((args) -> 0); + final WasmInstance instance1 = moduleInstantiate(wasm, source1, null); + final Object func = WebAssembly.instanceExport(instance1, "func"); final var f = new Executable((args) -> { final Object[] result = new Object[2]; result[0] = func; @@ -1430,8 +1437,8 @@ public void testMultiValueReferencePassThrough() throws IOException, Interrupted return InteropArray.create(result); }); final Dictionary importObject = Dictionary.create(new Object[]{"m", Dictionary.create(new Object[]{"f", f})}); - final WasmInstance instance = moduleInstantiate(wasm, source, importObject); - final Object main = WebAssembly.instanceExport(instance, "main"); + final WasmInstance instance2 = moduleInstantiate(wasm, source2, importObject); + final Object main = WebAssembly.instanceExport(instance2, "main"); final InteropLibrary lib = InteropLibrary.getUncached(); try { Object result = lib.execute(main); diff --git a/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/bytecode/BytecodeSuite.java b/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/bytecode/BytecodeSuite.java index e19969a153fa..2eca736fdadd 100644 --- a/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/bytecode/BytecodeSuite.java +++ b/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/bytecode/BytecodeSuite.java @@ -159,7 +159,7 @@ public void testInvalidResultType() { @Test public void testBrU8Min() { - test(b -> b.addBranch(1), new byte[]{Bytecode.BR_U8, 0x00}); + test(b -> b.addBranch(1, RuntimeBytecodeGen.BranchOp.BR), new byte[]{Bytecode.BR_U8, 0x00}); } @Test @@ -171,18 +171,18 @@ public void testBrU8Max() { for (int i = 0; i < 254; i++) { b.addOp(0); } - b.addBranch(0); + b.addBranch(0, RuntimeBytecodeGen.BranchOp.BR); }, expected); } @Test public void testBrI32MinForward() { - test(b -> b.addBranch(2), new byte[]{Bytecode.BR_I32, 0x01, 0x00, 0x00, 0x00}); + test(b -> b.addBranch(2, RuntimeBytecodeGen.BranchOp.BR), new byte[]{Bytecode.BR_I32, 0x01, 0x00, 0x00, 0x00}); } @Test public void testBrI32MaxForward() { - test(b -> b.addBranch(2147483647), new byte[]{Bytecode.BR_I32, (byte) 0xFE, (byte) 0xFF, (byte) 0xFF, 0x7F}); + test(b -> b.addBranch(2147483647, RuntimeBytecodeGen.BranchOp.BR), new byte[]{Bytecode.BR_I32, (byte) 0xFE, (byte) 0xFF, (byte) 0xFF, 0x7F}); } @Test @@ -197,13 +197,13 @@ public void testBrI32MinBackward() { for (int i = 0; i < 255; i++) { b.addOp(0); } - b.addBranch(0); + b.addBranch(0, RuntimeBytecodeGen.BranchOp.BR); }, expected); } @Test public void testBrIfU8Min() { - test(b -> b.addBranchIf(1), new byte[]{Bytecode.BR_IF_U8, 0x00, 0x00, 0x00}); + test(b -> b.addBranch(1, RuntimeBytecodeGen.BranchOp.BR_IF), new byte[]{Bytecode.BR_IF_U8, 0x00, 0x00, 0x00}); } @Test @@ -215,18 +215,18 @@ public void testBrIfU8Max() { for (int i = 0; i < 254; i++) { b.addOp(0); } - b.addBranchIf(0); + b.addBranch(0, RuntimeBytecodeGen.BranchOp.BR_IF); }, expected); } @Test public void testBrIfI32MinForward() { - test(b -> b.addBranchIf(2), new byte[]{Bytecode.BR_IF_I32, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00}); + test(b -> b.addBranch(2, RuntimeBytecodeGen.BranchOp.BR_IF), new byte[]{Bytecode.BR_IF_I32, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00}); } @Test public void testBrIfI32MaxForward() { - test(b -> b.addBranchIf(2147483647), new byte[]{Bytecode.BR_IF_I32, (byte) 0xFE, (byte) 0xFF, (byte) 0xFF, 0x7F, 0x00, 0x00}); + test(b -> b.addBranch(2147483647, RuntimeBytecodeGen.BranchOp.BR_IF), new byte[]{Bytecode.BR_IF_I32, (byte) 0xFE, (byte) 0xFF, (byte) 0xFF, 0x7F, 0x00, 0x00}); } @Test @@ -241,7 +241,7 @@ public void testBrIfI32MinBackward() { for (int i = 0; i < 255; i++) { b.addOp(0); } - b.addBranchIf(0); + b.addBranch(0, RuntimeBytecodeGen.BranchOp.BR_IF); }, expected); } @@ -644,137 +644,137 @@ public void testDataRuntimeHeaderMinI32Length() { @Test public void testElemHeaderMin() { - test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 0, null, -1), new byte[]{0x40, 0x10, 0x00}); + test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 0, null, -1), new byte[]{0x40, 0x10, WasmType.FUNCREF_TYPE, 0x00}); } @Test public void testElemHeaderMinU8Count() { - test(b -> b.addElemHeader(SegmentMode.ACTIVE, 1, WasmType.FUNCREF_TYPE, 0, null, -1), new byte[]{0x40, 0x10, 0x01}); + test(b -> b.addElemHeader(SegmentMode.ACTIVE, 1, WasmType.FUNCREF_TYPE, 0, null, -1), new byte[]{0x40, 0x10, WasmType.FUNCREF_TYPE, 0x01}); } @Test public void testElemHeaderMaxU8Count() { - test(b -> b.addElemHeader(SegmentMode.ACTIVE, 255, WasmType.FUNCREF_TYPE, 0, null, -1), new byte[]{0x40, 0x10, (byte) 0xFF}); + test(b -> b.addElemHeader(SegmentMode.ACTIVE, 255, WasmType.FUNCREF_TYPE, 0, null, -1), new byte[]{0x40, 0x10, WasmType.FUNCREF_TYPE, (byte) 0xFF}); } @Test public void testElemHeaderMinU16Count() { - test(b -> b.addElemHeader(SegmentMode.ACTIVE, 256, WasmType.FUNCREF_TYPE, 0, null, -1), new byte[]{(byte) 0x80, 0x10, 0x00, 0x01}); + test(b -> b.addElemHeader(SegmentMode.ACTIVE, 256, WasmType.FUNCREF_TYPE, 0, null, -1), new byte[]{(byte) 0x80, 0x10, WasmType.FUNCREF_TYPE, 0x00, 0x01}); } @Test public void testElemHeaderMaxU16Count() { - test(b -> b.addElemHeader(SegmentMode.ACTIVE, 65535, WasmType.FUNCREF_TYPE, 0, null, -1), new byte[]{(byte) 0x80, 0x10, (byte) 0xFF, (byte) 0xFF}); + test(b -> b.addElemHeader(SegmentMode.ACTIVE, 65535, WasmType.FUNCREF_TYPE, 0, null, -1), new byte[]{(byte) 0x80, 0x10, WasmType.FUNCREF_TYPE, (byte) 0xFF, (byte) 0xFF}); } @Test public void testElemHeaderMinI32Count() { - test(b -> b.addElemHeader(SegmentMode.ACTIVE, 65536, WasmType.FUNCREF_TYPE, 0, null, -1), new byte[]{(byte) 0xC0, 0x10, 0x00, 0x00, 0x01, 0x00}); + test(b -> b.addElemHeader(SegmentMode.ACTIVE, 65536, WasmType.FUNCREF_TYPE, 0, null, -1), new byte[]{(byte) 0xC0, 0x10, WasmType.FUNCREF_TYPE, 0x00, 0x00, 0x01, 0x00}); } @Test public void testElemHeaderPassive() { - test(b -> b.addElemHeader(SegmentMode.PASSIVE, 8, WasmType.FUNCREF_TYPE, 0, null, -1), new byte[]{0x40, 0x11, 0x08}); + test(b -> b.addElemHeader(SegmentMode.PASSIVE, 8, WasmType.FUNCREF_TYPE, 0, null, -1), new byte[]{0x40, 0x11, WasmType.FUNCREF_TYPE, 0x08}); } @Test public void testElemHeaderDeclarative() { - test(b -> b.addElemHeader(SegmentMode.DECLARATIVE, 8, WasmType.FUNCREF_TYPE, 0, null, -1), new byte[]{0x40, 0x12, 0x08}); + test(b -> b.addElemHeader(SegmentMode.DECLARATIVE, 8, WasmType.FUNCREF_TYPE, 0, null, -1), new byte[]{0x40, 0x12, WasmType.FUNCREF_TYPE, 0x08}); } @Test public void testElemHeaderExternref() { - test(b -> b.addElemHeader(SegmentMode.ACTIVE, 8, WasmType.EXTERNREF_TYPE, 0, null, -1), new byte[]{0x40, 0x20, 0x08}); + test(b -> b.addElemHeader(SegmentMode.ACTIVE, 8, WasmType.EXTERNREF_TYPE, 0, null, -1), new byte[]{0x40, 0x10, WasmType.EXTERNREF_TYPE, 0x08}); } @Test public void testElemHeaderExnref() { - test(b -> b.addElemHeader(SegmentMode.ACTIVE, 8, WasmType.EXNREF_TYPE, 0, null, -1), new byte[]{0x40, 0x30, 0x08}); + test(b -> b.addElemHeader(SegmentMode.ACTIVE, 8, WasmType.EXNREF_TYPE, 0, null, -1), new byte[]{0x40, 0x10, WasmType.EXNREF_TYPE, 0x08}); } @Test public void testElemHeaderMinU8TableIndex() { - test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 1, null, -1), new byte[]{0x50, 0x10, 0x00, 0x01}); + test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 1, null, -1), new byte[]{0x50, 0x10, WasmType.FUNCREF_TYPE, 0x00, 0x01}); } @Test public void testElemHeaderMaxU8TableIndex() { - test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 255, null, -1), new byte[]{0x50, 0x10, 0x00, (byte) 0xFF}); + test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 255, null, -1), new byte[]{0x50, 0x10, WasmType.FUNCREF_TYPE, 0x00, (byte) 0xFF}); } @Test public void testElemHeaderMinU16TableIndex() { - test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 256, null, -1), new byte[]{0x60, 0x10, 0x00, 0x00, 0x01}); + test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 256, null, -1), new byte[]{0x60, 0x10, WasmType.FUNCREF_TYPE, 0x00, 0x00, 0x01}); } @Test public void testElemHeaderMaxU16TableIndex() { - test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 65535, null, -1), new byte[]{0x60, 0x10, 0x00, (byte) 0xFF, (byte) 0xFF}); + test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 65535, null, -1), new byte[]{0x60, 0x10, WasmType.FUNCREF_TYPE, 0x00, (byte) 0xFF, (byte) 0xFF}); } @Test public void testElemHeaderMinI32TableIndex() { - test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 65536, null, -1), new byte[]{0x70, 0x10, 0x00, 0x00, 0x00, 0x01, 0x00}); + test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 65536, null, -1), new byte[]{0x70, 0x10, WasmType.FUNCREF_TYPE, 0x00, 0x00, 0x00, 0x01, 0x00}); } @Test public void testElemHeaderMinU8OffsetBytecodeLength() { byte[] offsetBytecode = new byte[0]; test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 0, offsetBytecode, -1), - byteArrayConcat(new byte[]{0x44, 0x10, 0x00, 0x00}, offsetBytecode)); + byteArrayConcat(new byte[]{0x44, 0x10, WasmType.FUNCREF_TYPE, 0x00, 0x00}, offsetBytecode)); } @Test public void testElemHeaderMaxU8OffsetBytecodeLength() { byte[] offsetBytecode = new byte[255]; test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 0, offsetBytecode, -1), - byteArrayConcat(new byte[]{0x44, 0x10, 0x00, (byte) 0xFF}, offsetBytecode)); + byteArrayConcat(new byte[]{0x44, 0x10, WasmType.FUNCREF_TYPE, 0x00, (byte) 0xFF}, offsetBytecode)); } @Test public void testElemHeaderMinU16OffsetBytecodeLength() { byte[] offsetBytecode = new byte[256]; test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 0, offsetBytecode, -1), - byteArrayConcat(new byte[]{0x48, 0x10, 0x00, 0x00, 0x01}, offsetBytecode)); + byteArrayConcat(new byte[]{0x48, 0x10, WasmType.FUNCREF_TYPE, 0x00, 0x00, 0x01}, offsetBytecode)); } @Test public void testElemHeaderMaxU16OffsetBytecodeLength() { byte[] offsetBytecode = new byte[65535]; test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 0, offsetBytecode, -1), - byteArrayConcat(new byte[]{0x48, 0x10, 0x00, (byte) 0xFF, (byte) 0xFF}, offsetBytecode)); + byteArrayConcat(new byte[]{0x48, 0x10, WasmType.FUNCREF_TYPE, 0x00, (byte) 0xFF, (byte) 0xFF}, offsetBytecode)); } @Test public void testElemHeaderMinI32OffsetBytecodeLength() { byte[] offsetBytecode = new byte[65536]; test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 0, offsetBytecode, -1), - byteArrayConcat(new byte[]{0x4C, 0x10, 0x00, 0x00, 0x00, 0x01, 0x00}, offsetBytecode)); + byteArrayConcat(new byte[]{0x4C, 0x10, WasmType.FUNCREF_TYPE, 0x00, 0x00, 0x00, 0x01, 0x00}, offsetBytecode)); } @Test public void testElemHeaderMinU8OffsetAddress() { - test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 0, null, 0), new byte[]{0x41, 0x10, 0x00, 0x00}); + test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 0, null, 0), new byte[]{0x41, 0x10, WasmType.FUNCREF_TYPE, 0x00, 0x00}); } @Test public void testElemHeaderMaxU8OffsetAddress() { - test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 0, null, 255), new byte[]{0x41, 0x10, 0x00, (byte) 0xFF}); + test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 0, null, 255), new byte[]{0x41, 0x10, WasmType.FUNCREF_TYPE, 0x00, (byte) 0xFF}); } @Test public void testElemHeaderMinU16OffsetAddress() { - test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 0, null, 256), new byte[]{0x42, 0x10, 0x00, 0x00, 0x01}); + test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 0, null, 256), new byte[]{0x42, 0x10, WasmType.FUNCREF_TYPE, 0x00, 0x00, 0x01}); } @Test public void testElemHeaderMaxU16OffsetAddress() { - test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 0, null, 65535), new byte[]{0x42, 0x10, 0x00, (byte) 0xFF, (byte) 0xFF}); + test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 0, null, 65535), new byte[]{0x42, 0x10, WasmType.FUNCREF_TYPE, 0x00, (byte) 0xFF, (byte) 0xFF}); } @Test public void testElemHeaderMinI32OffsetAddress() { - test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 0, null, 65536), new byte[]{0x43, 0x10, 0x00, 0x00, 0x00, 0x01, 0x00}); + test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 0, null, 65536), new byte[]{0x43, 0x10, WasmType.FUNCREF_TYPE, 0x00, 0x00, 0x00, 0x01, 0x00}); } @Test diff --git a/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/debugging/DebugValidationSuite.java b/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/debugging/DebugValidationSuite.java index 3063bf52f758..d4ec02972d01 100644 --- a/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/debugging/DebugValidationSuite.java +++ b/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/debugging/DebugValidationSuite.java @@ -59,7 +59,7 @@ public class DebugValidationSuite extends AbstractBinarySuite { private static AbstractBinarySuite.BinaryBuilder getDefaultDebugBuilder() { - return newBuilder().addType(EMPTY_BYTES, EMPTY_BYTES).addFunction((byte) 0, EMPTY_BYTES, "0B").addFunctionExport((byte) 0, "_main"); + return newBuilder().addType(EMPTY_INTS, EMPTY_INTS).addFunction(0, EMPTY_INTS, "0B").addFunctionExport(0, "_main"); } @Test diff --git a/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/validation/ReferenceTypesValidationSuite.java b/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/validation/ReferenceTypesValidationSuite.java index e5b22eb4fd84..31dcf9980adc 100644 --- a/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/validation/ReferenceTypesValidationSuite.java +++ b/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/validation/ReferenceTypesValidationSuite.java @@ -75,9 +75,9 @@ private static AbstractBinarySuite.BinaryBuilder getDefaultTableInitBuilder(Stri // ;; main hex code // ) // ) - return newBuilder().addType(EMPTY_BYTES, new byte[]{WasmType.I32_TYPE}).addTable((byte) 3, (byte) 3, WasmType.FUNCREF_TYPE).addFunction((byte) 0, EMPTY_BYTES, "41 00 0B").addFunction((byte) 0, - EMPTY_BYTES, "41 01 0B").addFunction((byte) 0, EMPTY_BYTES, "41 02 0B").addFunction((byte) 0, EMPTY_BYTES, mainHexCode).addFunctionExport( - (byte) 3, "main"); + return newBuilder().addType(EMPTY_INTS, new int[]{WasmType.I32_TYPE}).addTable(3, 3, WasmType.FUNCREF_TYPE).addFunction(0, EMPTY_INTS, "41 00 0B").addFunction(0, + EMPTY_INTS, "41 01 0B").addFunction(0, EMPTY_INTS, "41 02 0B").addFunction(0, EMPTY_INTS, mainHexCode).addFunctionExport( + 3, "main"); } @Test @@ -222,7 +222,7 @@ public void testTableInitDeclarativeType3() throws IOException { // i32.const 0 // // (elem declare (ref.func 3)) - final byte[] binary = getDefaultTableInitBuilder("D2 03 1A 41 00 0B").addFunction((byte) 0, EMPTY_BYTES, "41 03 0B").addElements("03 00 01 03").build(); + final byte[] binary = getDefaultTableInitBuilder("D2 03 1A 41 00 0B").addFunction(0, EMPTY_INTS, "41 03 0B").addElements("03 00 01 03").build(); runRuntimeTest(binary, instance -> { Value main = instance.getMember("main"); Value result = main.execute(); @@ -238,7 +238,7 @@ public void testTableInitDeclarativeType7() throws IOException { // i32.const 0 // // (elem declare (ref.func 3)) - final byte[] binary = getDefaultTableInitBuilder("D2 03 1A 41 00 0B").addFunction((byte) 0, EMPTY_BYTES, "41 03 0B").addElements("07 70 01 D2 03 0B").build(); + final byte[] binary = getDefaultTableInitBuilder("D2 03 1A 41 00 0B").addFunction(0, EMPTY_INTS, "41 03 0B").addElements("07 70 01 D2 03 0B").build(); runRuntimeTest(binary, instance -> { Value main = instance.getMember("main"); Value result = main.execute(); @@ -535,15 +535,14 @@ public void testMultipleTables() throws IOException { // (table 1 1 funcref) // (table 1 1 externref) // (table 1 1 exnref) - final byte[] binary = newBuilder().addTable((byte) 1, (byte) 1, WasmType.FUNCREF_TYPE).addTable((byte) 1, (byte) 1, WasmType.EXTERNREF_TYPE).addTable((byte) 1, (byte) 1, - WasmType.EXNREF_TYPE).build(); + final byte[] binary = newBuilder().addTable(1, 1, WasmType.FUNCREF_TYPE).addTable(1, 1, WasmType.EXTERNREF_TYPE).addTable(1, 1, WasmType.EXNREF_TYPE).build(); runParserTest(binary, options -> options.option("wasm.Exceptions", "true"), Context::eval); } @Test public void testTableInvalidElemType() throws IOException { // (table 1 1 i32) - final byte[] binary = newBuilder().addTable((byte) 1, (byte) 1, WasmType.I32_TYPE).build(); + final byte[] binary = newBuilder().addTable(1, 1, WasmType.I32_TYPE).build(); runParserTest(binary, (context, source) -> { try { context.eval(source); @@ -564,7 +563,7 @@ private static AbstractBinarySuite.BinaryBuilder getDefaultMemoryInitBuilder(Str // ;; main hex code // ) // ) - return newBuilder().addType(EMPTY_BYTES, new byte[]{WasmType.I32_TYPE}).addMemory((byte) 1, (byte) 1).addFunction((byte) 0, EMPTY_BYTES, mainHexCode).addFunctionExport((byte) 0, "main"); + return newBuilder().addType(EMPTY_INTS, new int[]{WasmType.I32_TYPE}).addMemory(1, 1).addFunction(0, EMPTY_INTS, mainHexCode).addFunctionExport(0, "main"); } @Test @@ -875,8 +874,8 @@ public void testGlobalWithNull() throws IOException { // (func (export "main") (type 0) // global.get 0 // ) - final byte[] binary = newBuilder().addGlobal(GlobalModifier.CONSTANT, WasmType.EXTERNREF_TYPE, "D0 6F 0B").addType(EMPTY_BYTES, new byte[]{WasmType.EXTERNREF_TYPE}).addFunction((byte) 0, - EMPTY_BYTES, "23 00 0B").addFunctionExport((byte) 0, "main").build(); + final byte[] binary = newBuilder().addGlobal(GlobalModifier.CONSTANT, WasmType.EXTERNREF_TYPE, "D0 6F 0B").addType(EMPTY_INTS, new int[]{WasmType.EXTERNREF_TYPE}).addFunction(0, + EMPTY_INTS, "23 00 0B").addFunctionExport(0, "main").build(); runRuntimeTest(binary, instance -> { Value main = instance.getMember("main"); Value result = main.execute(); @@ -891,8 +890,8 @@ public void testGlobalWithNullException() throws IOException { // (func (export "main") (type 0) // global.get 0 // ) - final byte[] binary = newBuilder().addGlobal(GlobalModifier.CONSTANT, WasmType.EXNREF_TYPE, "D0 69 0B").addType(EMPTY_BYTES, new byte[]{WasmType.EXNREF_TYPE}).addFunction((byte) 0, - EMPTY_BYTES, "23 00 0B").addFunctionExport((byte) 0, "main").build(); + final byte[] binary = newBuilder().addGlobal(GlobalModifier.CONSTANT, WasmType.EXNREF_TYPE, "D0 69 0B").addType(EMPTY_INTS, new int[]{WasmType.EXNREF_TYPE}).addFunction(0, + EMPTY_INTS, "23 00 0B").addFunctionExport(0, "main").build(); runRuntimeTest(binary, options -> options.option("wasm.Exceptions", "true"), instance -> { Value main = instance.getMember("main"); Value result = main.execute(); @@ -915,9 +914,9 @@ public void testGlobalWithFunction() throws IOException { // i32.const 0 // call_indirect 0 (type 0) // ) - final byte[] binary = newBuilder().addGlobal(GlobalModifier.CONSTANT, WasmType.FUNCREF_TYPE, "D2 00 0B").addTable((byte) 1, (byte) 1, WasmType.FUNCREF_TYPE).addType(EMPTY_BYTES, - new byte[]{WasmType.I32_TYPE}).addFunction((byte) 0, EMPTY_BYTES, "41 01 0B").addFunction((byte) 0, EMPTY_BYTES, "41 00 23 00 26 00 41 00 11 00 00 0B").addFunctionExport( - (byte) 1, "main").build(); + final byte[] binary = newBuilder().addGlobal(GlobalModifier.CONSTANT, WasmType.FUNCREF_TYPE, "D2 00 0B").addTable(1, 1, WasmType.FUNCREF_TYPE).addType(EMPTY_INTS, + new int[]{WasmType.I32_TYPE}).addFunction(0, EMPTY_INTS, "41 01 0B").addFunction(0, EMPTY_INTS, "41 00 23 00 26 00 41 00 11 00 00 0B").addFunctionExport( + 1, "main").build(); runRuntimeTest(binary, instance -> { Value main = instance.getMember("main"); Value result = main.execute(); @@ -956,13 +955,13 @@ public void testMultiValueWithReferenceTypes() throws IOException { // table.set 1 // ) // (elem (table 0) (ref.func 0) (ref.func 1)) - final byte[] binary = newBuilder().addTable((byte) 2, (byte) 2, WasmType.FUNCREF_TYPE).addTable((byte) 2, (byte) 2, WasmType.EXTERNREF_TYPE).addType(EMPTY_BYTES, - new byte[]{WasmType.I32_TYPE}).addType(new byte[]{WasmType.I32_TYPE, WasmType.FUNCREF_TYPE}, new byte[]{WasmType.EXTERNREF_TYPE, WasmType.FUNCREF_TYPE}).addType(EMPTY_BYTES, - new byte[]{WasmType.EXTERNREF_TYPE, WasmType.FUNCREF_TYPE}).addType(new byte[]{WasmType.I32_TYPE, WasmType.EXTERNREF_TYPE}, EMPTY_BYTES).addFunction((byte) 0, - EMPTY_BYTES, "41 01 0B").addFunction((byte) 0, EMPTY_BYTES, "41 02 0B").addFunction((byte) 1, EMPTY_BYTES, "20 00 25 01 20 01 0B").addFunction( - (byte) 2, EMPTY_BYTES, "41 01 41 00 25 00 10 02 0B").addFunction((byte) 3, EMPTY_BYTES, - "20 00 20 01 26 01 0B").addFunctionExport((byte) 3, - "main").addFunctionExport((byte) 4, "setRef").addElements("00 41 00 0B 02 00 01").build(); + final byte[] binary = newBuilder().addTable(2, 2, WasmType.FUNCREF_TYPE).addTable(2, 2, WasmType.EXTERNREF_TYPE).addType(EMPTY_INTS, + new int[]{WasmType.I32_TYPE}).addType(new int[]{WasmType.I32_TYPE, WasmType.FUNCREF_TYPE}, new int[]{WasmType.EXTERNREF_TYPE, WasmType.FUNCREF_TYPE}).addType(EMPTY_INTS, + new int[]{WasmType.EXTERNREF_TYPE, WasmType.FUNCREF_TYPE}).addType(new int[]{WasmType.I32_TYPE, WasmType.EXTERNREF_TYPE}, EMPTY_INTS).addFunction(0, + EMPTY_INTS, "41 01 0B").addFunction(0, EMPTY_INTS, "41 02 0B").addFunction(1, EMPTY_INTS, "20 00 25 01 20 01 0B").addFunction( + 2, EMPTY_INTS, "41 01 41 00 25 00 10 02 0B").addFunction(3, EMPTY_INTS, + "20 00 20 01 26 01 0B").addFunctionExport(3, + "main").addFunctionExport(4, "setRef").addElements("00 41 00 0B 02 00 01").build(); runRuntimeTest(binary, instance -> { Value setRef = instance.getMember("setRef"); setRef.execute(0, "foo"); diff --git a/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/validation/ValidationSuite.java b/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/validation/ValidationSuite.java index 1e4c35c44603..33de86503d32 100644 --- a/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/validation/ValidationSuite.java +++ b/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/validation/ValidationSuite.java @@ -167,7 +167,7 @@ public static Collection data() { Failure.Type.MALFORMED), binaryCase( "Table - import with invalid elemtype", - "Invalid element type for table import: 0x6F should = 0x70", + "Invalid element type for table import: -17 should = -16", // (import "a" "b" (table 0 1 externref)) "00 61 73 6D 01 00 00 00 02 09 01 01 61 01 62 01 6F 00 01", Failure.Type.MALFORMED), @@ -220,13 +220,13 @@ public static Collection data() { // The type `C.types[x]` must be defined in the context. stringCase( "Function - invalid type index", - "unknown type: 1 should be < 1", + "Function type variable 1 out of range. (max 0)", "(type (func (result i32))) (func (export \"f\") (type 1))", Failure.Type.INVALID), stringCase( "Function - invalid type index", - "unknown type: 4294967254 should be < 1", - "(type (func (result i32))) (func (export \"f\") (type 4294967254))", + "Function type variable 1073741823 out of range. (max 0)", + "(type (func (result i32))) (func (export \"f\") (type 1073741823))", Failure.Type.INVALID), // Under the context `C'`, the expression `express` must be valid with type diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Assert.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Assert.java index 55f94fc23eed..9d8e5aaa8fe5 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Assert.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Assert.java @@ -199,12 +199,6 @@ public static void assertTrue(boolean condition, String message, Failure failure } } - public static void assertFunctionTypeEquals(SymbolTable.FunctionType t1, SymbolTable.FunctionType t2, Failure failure) throws WasmException { - if (!t1.equals(t2)) { - fail(failure, "%s: %s should = %s", failure.name, t1, t2); - } - } - @TruffleBoundary public static RuntimeException fail(Failure failure, String format, Object... args) throws WasmException { throw WasmException.format(failure, format, args); diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java index c2609ef80789..7790fc7ee553 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java @@ -40,6 +40,7 @@ */ package org.graalvm.wasm; +import static java.lang.Integer.compareUnsigned; import static org.graalvm.wasm.Assert.assertByteEqual; import static org.graalvm.wasm.Assert.assertIntEqual; import static org.graalvm.wasm.Assert.assertIntLessOrEqual; @@ -48,17 +49,25 @@ import static org.graalvm.wasm.Assert.assertUnsignedIntLessOrEqual; import static org.graalvm.wasm.Assert.assertUnsignedLongLessOrEqual; import static org.graalvm.wasm.Assert.fail; +import static org.graalvm.wasm.WasmType.BOT; import static org.graalvm.wasm.WasmType.EXNREF_TYPE; +import static org.graalvm.wasm.WasmType.EXN_HEAPTYPE; import static org.graalvm.wasm.WasmType.EXTERNREF_TYPE; +import static org.graalvm.wasm.WasmType.EXTERN_HEAPTYPE; import static org.graalvm.wasm.WasmType.F32_TYPE; import static org.graalvm.wasm.WasmType.F64_TYPE; import static org.graalvm.wasm.WasmType.FUNCREF_TYPE; +import static org.graalvm.wasm.WasmType.FUNC_HEAPTYPE; import static org.graalvm.wasm.WasmType.I32_TYPE; import static org.graalvm.wasm.WasmType.I64_TYPE; -import static org.graalvm.wasm.WasmType.NULL_TYPE; +import static org.graalvm.wasm.WasmType.REF_NULL_TYPE_HEADER; +import static org.graalvm.wasm.WasmType.REF_TYPE_HEADER; import static org.graalvm.wasm.WasmType.V128_TYPE; -import static org.graalvm.wasm.WasmType.VOID_TYPE; +import static org.graalvm.wasm.WasmType.VOID_BLOCK_TYPE; import static org.graalvm.wasm.constants.Bytecode.vectorOpcodeToBytecode; +import static org.graalvm.wasm.constants.BytecodeBitEncoding.ELEM_ITEM_REF_FUNC_ENTRY_PREFIX; +import static org.graalvm.wasm.constants.BytecodeBitEncoding.ELEM_ITEM_GLOBAL_GET_ENTRY_PREFIX; +import static org.graalvm.wasm.constants.BytecodeBitEncoding.ELEM_ITEM_REF_NULL_ENTRY_PREFIX; import static org.graalvm.wasm.constants.Sizes.MAX_MEMORY_64_DECLARATION_SIZE; import static org.graalvm.wasm.constants.Sizes.MAX_MEMORY_DECLARATION_SIZE; import static org.graalvm.wasm.constants.Sizes.MAX_TABLE_DECLARATION_SIZE; @@ -76,7 +85,7 @@ import org.graalvm.collections.Pair; import org.graalvm.wasm.api.Vector128; import org.graalvm.wasm.api.Vector128Shape; -import org.graalvm.wasm.collection.ByteArrayList; +import org.graalvm.wasm.collection.IntArrayList; import org.graalvm.wasm.constants.Bytecode; import org.graalvm.wasm.constants.BytecodeBitEncoding; import org.graalvm.wasm.constants.ExceptionHandlerType; @@ -98,6 +107,7 @@ import org.graalvm.wasm.parser.ir.CodeEntry; import org.graalvm.wasm.parser.validation.ExceptionHandler; import org.graalvm.wasm.parser.validation.ParserState; +import org.graalvm.wasm.parser.validation.ValidationErrors; import com.oracle.truffle.api.CompilerDirectives; import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; @@ -111,6 +121,8 @@ public class BinaryParser extends BinaryStreamParser { private static final int MAGIC = 0x6d736100; private static final int VERSION = 0x00000001; + private static final int[] EMPTY_TYPES = new int[0]; + private final WasmModule module; private final WasmContext wasmContext; private final int[] multiResult; @@ -124,6 +136,7 @@ public class BinaryParser extends BinaryStreamParser { private final boolean threads; private final boolean simd; private final boolean exceptions; + private final boolean typedFunctionReferences; @TruffleBoundary public BinaryParser(WasmModule module, WasmContext context, byte[] data) { @@ -140,6 +153,7 @@ public BinaryParser(WasmModule module, WasmContext context, byte[] data) { this.threads = context.getContextOptions().supportThreads(); this.simd = context.getContextOptions().supportSIMD(); this.exceptions = context.getContextOptions().supportExceptions(); + this.typedFunctionReferences = context.getContextOptions().supportTypedFunctionReferences(); } @TruffleBoundary @@ -180,6 +194,7 @@ private void readSymbolSections() { final int size = readLength(); final int startOffset = offset; + final int endOffset = startOffset + size; switch (sectionID) { case Section.CUSTOM: readCustomSection(size, customData); @@ -194,7 +209,7 @@ private void readSymbolSections() { readFunctionSection(); break; case Section.TABLE: - readTableSection(); + readTableSection(endOffset); break; case Section.MEMORY: readMemorySection(); @@ -203,7 +218,7 @@ private void readSymbolSections() { readTagSection(); break; case Section.GLOBAL: - readGlobalSection(); + readGlobalSection(endOffset); break; case Section.EXPORT: readExportSection(); @@ -212,7 +227,7 @@ private void readSymbolSections() { readStartSection(); break; case Section.ELEMENT: - readElementSection(bytecode); + readElementSection(bytecode, endOffset); break; case Section.DATA_COUNT: if (bulkMemoryAndRefTypes) { @@ -228,7 +243,7 @@ private void readSymbolSections() { break; case Section.DATA: dataSectionPresent = true; - readDataSection(bytecode); + readDataSection(bytecode, endOffset); break; } assertIntEqual(offset - startOffset, size, Failure.SECTION_SIZE_MISMATCH, "Declared section (0x%02X) size is incorrect", sectionID); @@ -426,9 +441,9 @@ private void readImportSection() { break; } case ImportIdentifier.TABLE: { - final byte elemType = readRefType(exceptions); + final int elemType = readRefType(); if (!bulkMemoryAndRefTypes) { - assertByteEqual(elemType, FUNCREF_TYPE, "Invalid element type for table import", Failure.UNSPECIFIED_MALFORMED); + assertIntEqual(elemType, FUNCREF_TYPE, Failure.UNSPECIFIED_MALFORMED, "Invalid element type for table import"); } readTableLimits(multiResult); final int tableIndex = module.tableCount(); @@ -444,7 +459,7 @@ private void readImportSection() { break; } case ImportIdentifier.GLOBAL: { - byte type = readValueType(bulkMemoryAndRefTypes, simd, exceptions); + int type = readValueType(); byte mutability = readMutability(); int globalIndex = module.symbolTable().numGlobals(); module.symbolTable().importGlobal(moduleName, memberName, globalIndex, type, mutability); @@ -472,20 +487,37 @@ private void readFunctionSection() { module.limits().checkFunctionCount(functionCount); for (int functionIndex = 0; functionIndex != functionCount; functionIndex++) { assertTrue(!isEOF(), Failure.LENGTH_OUT_OF_BOUNDS); - int functionTypeIndex = readUnsignedInt32(); + int functionTypeIndex = readTypeIndex(); module.symbolTable().declareFunction(functionTypeIndex); } } - private void readTableSection() { + private void readTableSection(int endOffset) { final int tableCount = readLength(); final int startingTableIndex = module.tableCount(); module.limits().checkTableCount(startingTableIndex + tableCount); for (int tableIndex = startingTableIndex; tableIndex != startingTableIndex + tableCount; tableIndex++) { assertTrue(!isEOF(), Failure.LENGTH_OUT_OF_BOUNDS); - final byte elemType = readRefType(exceptions); - readTableLimits(multiResult); - module.symbolTable().allocateTable(tableIndex, multiResult[0], multiResult[1], elemType, bulkMemoryAndRefTypes); + final int elemType; + final Object initValue; + final byte[] initBytecode; + if (peek1(data, offset) == 0x40 && peek1(data, offset + 1) == 0x00) { + Assert.assertTrue(typedFunctionReferences, Failure.MALFORMED_VALUE_TYPE); + offset += 2; + elemType = readRefType(); + readTableLimits(multiResult); + Pair initExpression = readConstantExpression(elemType, endOffset); + initValue = initExpression.getLeft(); + // Drop the initializer bytecode if we can eval the initializer during parsing + initBytecode = initValue == null ? initExpression.getRight() : null; + } else { + elemType = readRefType(); + readTableLimits(multiResult); + initValue = null; + initBytecode = null; + Assert.assertTrue(WasmType.isNullable(elemType), "uninitialized table of non-nullable element type", Failure.TYPE_MISMATCH); + } + module.symbolTable().declareTable(tableIndex, multiResult[0], multiResult[1], elemType, initBytecode, initValue, bulkMemoryAndRefTypes); } } @@ -516,7 +548,7 @@ private void readCodeSection(RuntimeBytecodeGen bytecode, BytecodeGen functionDe final int codeEntrySize = readUnsignedInt32(); final int startOffset = offset; module.limits().checkFunctionSize(codeEntrySize); - final ByteArrayList locals = readCodeEntryLocals(); + final IntArrayList locals = readCodeEntryLocals(); final int localCount = locals.size() + module.function(importedFunctionCount + entryIndex).paramCount(); module.limits().checkLocalCount(localCount); // Store the function start offset, instruction start offset, and function end offset. @@ -529,33 +561,29 @@ private void readCodeSection(RuntimeBytecodeGen bytecode, BytecodeGen functionDe module.setCodeEntries(codeEntries); } - private CodeEntry readCodeEntry(int functionIndex, ByteArrayList locals, int endOffset, boolean hasNextFunction, RuntimeBytecodeGen bytecode, int codeEntryIndex) { + private CodeEntry readCodeEntry(int functionIndex, IntArrayList locals, int endOffset, boolean hasNextFunction, RuntimeBytecodeGen bytecode, int codeEntryIndex) { final WasmFunction function = module.symbolTable().function(functionIndex); int paramCount = function.paramCount(); - byte[] localTypes = new byte[function.paramCount() + locals.size()]; + int[] localTypes = new int[function.paramCount() + locals.size()]; for (int index = 0; index != paramCount; index++) { localTypes[index] = function.paramTypeAt(index); } for (int index = 0; index != locals.size(); index++) { localTypes[index + paramCount] = locals.get(index); } - byte[] resultTypes = new byte[function.resultCount()]; - for (int index = 0; index != resultTypes.length; index++) { - resultTypes[index] = function.resultTypeAt(index); - } - return readFunction(functionIndex, localTypes, resultTypes, endOffset, hasNextFunction, bytecode, codeEntryIndex, null); + return readFunction(functionIndex, localTypes, endOffset, hasNextFunction, bytecode, codeEntryIndex, null); } - private ByteArrayList readCodeEntryLocals() { + private IntArrayList readCodeEntryLocals() { final int localsGroupCount = readLength(); - final ByteArrayList localTypes = new ByteArrayList(); + final IntArrayList localTypes = new IntArrayList(); int localsLength = 0; for (int localGroup = 0; localGroup != localsGroupCount; localGroup++) { assertTrue(!isEOF(), Failure.LENGTH_OUT_OF_BOUNDS); final int groupLength = readUnsignedInt32(); localsLength += groupLength; module.limits().checkLocalCount(localsLength); - final byte t = readValueType(bulkMemoryAndRefTypes, simd, exceptions); + final int t = readValueType(); for (int i = 0; i != groupLength; ++i) { localTypes.add(t); } @@ -563,27 +591,26 @@ private ByteArrayList readCodeEntryLocals() { return localTypes; } - private byte[] extractBlockParamTypes(int typeIndex) { + private int[] extractBlockParamTypes(int typeIndex) { int paramCount = module.functionTypeParamCount(typeIndex); - byte[] params = new byte[paramCount]; + int[] params = new int[paramCount]; for (int i = 0; i < paramCount; i++) { params[i] = module.functionTypeParamTypeAt(typeIndex, i); } return params; } - private byte[] extractBlockResultTypes(int typeIndex) { + private int[] extractBlockResultTypes(int typeIndex) { int resultCount = module.functionTypeResultCount(typeIndex); - byte[] results = new byte[resultCount]; + int[] results = new int[resultCount]; for (int i = 0; i < resultCount; i++) { results[i] = module.functionTypeResultTypeAt(typeIndex, i); } return results; } - private static byte[] encapsulateResultType(int type) { + private static int[] encapsulateResultType(int type) { return switch (type) { - case VOID_TYPE -> WasmType.VOID_TYPE_ARRAY; case I32_TYPE -> WasmType.I32_TYPE_ARRAY; case I64_TYPE -> WasmType.I64_TYPE_ARRAY; case F32_TYPE -> WasmType.F32_TYPE_ARRAY; @@ -592,16 +619,18 @@ private static byte[] encapsulateResultType(int type) { case FUNCREF_TYPE -> WasmType.FUNCREF_TYPE_ARRAY; case EXTERNREF_TYPE -> WasmType.EXTERNREF_TYPE_ARRAY; case EXNREF_TYPE -> WasmType.EXNREF_TYPE_ARRAY; - default -> throw WasmException.create(Failure.UNSPECIFIED_INTERNAL); + default -> new int[]{type}; }; } - private CodeEntry readFunction(int functionIndex, byte[] locals, byte[] resultTypes, int sourceCodeEndOffset, boolean hasNextFunction, RuntimeBytecodeGen bytecode, + private CodeEntry readFunction(int functionIndex, int[] locals, int sourceCodeEndOffset, boolean hasNextFunction, RuntimeBytecodeGen bytecode, int codeEntryIndex, EconomicMap offsetToLineIndexMap) { - final ParserState state = new ParserState(bytecode); + final ParserState state = new ParserState(bytecode, module); final ArrayList callNodes = new ArrayList<>(); final int bytecodeStartOffset = bytecode.location(); - state.enterFunction(resultTypes); + int[] paramTypes = module.function(functionIndex).paramTypes(); + int[] resultTypes = module.function(functionIndex).resultTypes(); + state.enterFunction(paramTypes, resultTypes, locals); int opcode; end: while (offset < sourceCodeEndOffset) { @@ -623,20 +652,26 @@ private CodeEntry readFunction(int functionIndex, byte[] locals, byte[] resultTy state.addInstruction(Bytecode.NOP); break; case Instructions.BLOCK: { - final byte[] blockParamTypes; - final byte[] blockResultTypes; - readBlockType(multiResult, bulkMemoryAndRefTypes, simd, exceptions); + final int[] blockParamTypes; + final int[] blockResultTypes; + readBlockType(multiResult); // Extract value based on result arity. - if (multiResult[1] == SINGLE_RESULT_VALUE) { - blockParamTypes = WasmType.VOID_TYPE_ARRAY; - blockResultTypes = encapsulateResultType(multiResult[0]); - } else if (multiValue) { - int typeIndex = multiResult[0]; - state.checkFunctionTypeExists(typeIndex, module.typeCount()); - blockParamTypes = extractBlockParamTypes(typeIndex); - blockResultTypes = extractBlockResultTypes(typeIndex); - } else { - throw WasmException.create(Failure.DISABLED_MULTI_VALUE); + switch (multiResult[1]) { + case BLOCK_TYPE_VOID -> { + blockParamTypes = WasmType.VOID_TYPE_ARRAY; + blockResultTypes = WasmType.VOID_TYPE_ARRAY; + } + case BLOCK_TYPE_VALTYPE -> { + blockParamTypes = WasmType.VOID_TYPE_ARRAY; + blockResultTypes = encapsulateResultType(multiResult[0]); + } + case BLOCK_TYPE_TYPE_INDEX -> { + int typeIndex = multiResult[0]; + checkFunctionTypeExists(typeIndex); + blockParamTypes = extractBlockParamTypes(typeIndex); + blockResultTypes = extractBlockResultTypes(typeIndex); + } + default -> throw WasmException.create(Failure.DISABLED_MULTI_VALUE); } state.popAll(blockParamTypes); state.enterBlock(blockParamTypes, blockResultTypes); @@ -644,48 +679,60 @@ private CodeEntry readFunction(int functionIndex, byte[] locals, byte[] resultTy } case Instructions.LOOP: { // Jumps are targeting the loop instruction for OSR. - final byte[] loopParamTypes; - final byte[] loopResultTypes; - readBlockType(multiResult, bulkMemoryAndRefTypes, simd, exceptions); + final int[] loopParamTypes; + final int[] loopResultTypes; + readBlockType(multiResult); // Extract value based on result arity. - if (multiResult[1] == SINGLE_RESULT_VALUE) { - loopParamTypes = WasmType.VOID_TYPE_ARRAY; - loopResultTypes = encapsulateResultType(multiResult[0]); - } else if (multiValue) { - int typeIndex = multiResult[0]; - state.checkFunctionTypeExists(typeIndex, module.typeCount()); - loopParamTypes = extractBlockParamTypes(typeIndex); - loopResultTypes = extractBlockResultTypes(typeIndex); - } else { - throw WasmException.create(Failure.DISABLED_MULTI_VALUE); + switch (multiResult[1]) { + case BLOCK_TYPE_VOID -> { + loopParamTypes = WasmType.VOID_TYPE_ARRAY; + loopResultTypes = WasmType.VOID_TYPE_ARRAY; + } + case BLOCK_TYPE_VALTYPE -> { + loopParamTypes = WasmType.VOID_TYPE_ARRAY; + loopResultTypes = encapsulateResultType(multiResult[0]); + } + case BLOCK_TYPE_TYPE_INDEX -> { + int typeIndex = multiResult[0]; + checkFunctionTypeExists(typeIndex); + loopParamTypes = extractBlockParamTypes(typeIndex); + loopResultTypes = extractBlockResultTypes(typeIndex); + } + default -> throw WasmException.create(Failure.DISABLED_MULTI_VALUE); } state.popAll(loopParamTypes); state.enterLoop(loopParamTypes, loopResultTypes); break; } case Instructions.IF: { - state.popChecked(I32_TYPE); // condition - final byte[] ifParamTypes; - final byte[] ifResultTypes; - readBlockType(multiResult, bulkMemoryAndRefTypes, simd, exceptions); + final int[] ifParamTypes; + final int[] ifResultTypes; + readBlockType(multiResult); // Extract value based on result arity. - if (multiResult[1] == SINGLE_RESULT_VALUE) { - ifParamTypes = WasmType.VOID_TYPE_ARRAY; - ifResultTypes = encapsulateResultType(multiResult[0]); - } else if (multiValue) { - int typeIndex = multiResult[0]; - state.checkFunctionTypeExists(typeIndex, module.typeCount()); - ifParamTypes = extractBlockParamTypes(typeIndex); - ifResultTypes = extractBlockResultTypes(typeIndex); - } else { - throw WasmException.create(Failure.DISABLED_MULTI_VALUE); + switch (multiResult[1]) { + case BLOCK_TYPE_VOID -> { + ifParamTypes = WasmType.VOID_TYPE_ARRAY; + ifResultTypes = WasmType.VOID_TYPE_ARRAY; + } + case BLOCK_TYPE_VALTYPE -> { + ifParamTypes = WasmType.VOID_TYPE_ARRAY; + ifResultTypes = encapsulateResultType(multiResult[0]); + } + case BLOCK_TYPE_TYPE_INDEX -> { + int typeIndex = multiResult[0]; + checkFunctionTypeExists(typeIndex); + ifParamTypes = extractBlockParamTypes(typeIndex); + ifResultTypes = extractBlockResultTypes(typeIndex); + } + default -> throw WasmException.create(Failure.DISABLED_MULTI_VALUE); } + state.popChecked(I32_TYPE); // condition state.popAll(ifParamTypes); state.enterIf(ifParamTypes, ifResultTypes); break; } case Instructions.END: { - final byte[] endResultTypes = state.exit(multiValue); + final int[] endResultTypes = state.exit(multiValue); state.pushAll(endResultTypes); if (state.controlStackSize() == 0) { /* @@ -751,7 +798,7 @@ private CodeEntry readFunction(int functionIndex, byte[] locals, byte[] resultTy // Pop parameters final WasmFunction function = module.function(callFunctionIndex); - byte[] params = new byte[function.paramCount()]; + int[] params = new int[function.paramCount()]; for (int i = function.paramCount() - 1; i >= 0; --i) { params[i] = function.paramTypeAt(i); } @@ -761,18 +808,17 @@ private CodeEntry readFunction(int functionIndex, byte[] locals, byte[] resultTy if (!multiValue) { assertIntLessOrEqual(function.resultCount(), 1, Failure.INVALID_RESULT_ARITY); } - state.pushAll(function.type().resultTypes()); + state.pushAll(function.resultTypes()); state.addCall(callNodes.size(), callFunctionIndex); callNodes.add(new CallNode(bytecode.location(), callFunctionIndex)); break; } case Instructions.CALL_INDIRECT: { - final int expectedFunctionTypeIndex = readUnsignedInt32(); + final int expectedFunctionTypeIndex = readTypeIndex(); final int tableIndex = readTableIndex(); // Pop the function index to call state.popChecked(I32_TYPE); - state.checkFunctionTypeExists(expectedFunctionTypeIndex, module.typeCount()); - assertByteEqual(FUNCREF_TYPE, module.tableElementType(tableIndex), Failure.TYPE_MISMATCH); + Assert.assertTrue(module.matchesType(FUNCREF_TYPE, module.tableElementType(tableIndex)), Failure.TYPE_MISMATCH); // Pop parameters for (int i = module.functionTypeParamCount(expectedFunctionTypeIndex) - 1; i >= 0; --i) { @@ -783,7 +829,7 @@ private CodeEntry readFunction(int functionIndex, byte[] locals, byte[] resultTy if (!multiValue) { assertIntLessOrEqual(resultCount, 1, Failure.INVALID_RESULT_ARITY); } - byte[] callResultTypes = new byte[resultCount]; + int[] callResultTypes = new int[resultCount]; for (int i = 0; i < resultCount; i++) { callResultTypes[i] = module.functionTypeResultTypeAt(expectedFunctionTypeIndex, i); } @@ -793,7 +839,7 @@ private CodeEntry readFunction(int functionIndex, byte[] locals, byte[] resultTy break; } case Instructions.DROP: - final byte type = state.pop(); + final int type = state.pop(); if (WasmType.isNumberType(type)) { state.addInstruction(Bytecode.DROP); } else { @@ -802,11 +848,11 @@ private CodeEntry readFunction(int functionIndex, byte[] locals, byte[] resultTy break; case Instructions.SELECT: { state.popChecked(I32_TYPE); // condition - final byte t1 = state.pop(); // first operand - final byte t2 = state.pop(); // second operand + final int t1 = state.pop(); // first operand + final int t2 = state.pop(); // second operand assertTrue((WasmType.isNumberType(t1) || WasmType.isVectorType(t1)) && (WasmType.isNumberType(t2) || WasmType.isVectorType(t2)), Failure.TYPE_MISMATCH); - assertTrue(t1 == t2 || t1 == WasmType.UNKNOWN_TYPE || t2 == WasmType.UNKNOWN_TYPE, Failure.TYPE_MISMATCH); - final byte t = t1 == WasmType.UNKNOWN_TYPE ? t2 : t1; + assertTrue(t1 == t2 || t1 == BOT || t2 == BOT, Failure.TYPE_MISMATCH); + final int t = t1 == BOT ? t2 : t1; state.push(t); if (WasmType.isNumberType(t)) { state.addSelectInstruction(Bytecode.SELECT); @@ -819,7 +865,7 @@ private CodeEntry readFunction(int functionIndex, byte[] locals, byte[] resultTy checkBulkMemoryAndRefTypesSupport(opcode); final int length = readLength(); assertIntEqual(length, 1, Failure.INVALID_RESULT_ARITY); - final byte t = readValueType(bulkMemoryAndRefTypes, simd, exceptions); + final int t = readValueType(); state.popChecked(I32_TYPE); state.popChecked(t); state.popChecked(t); @@ -833,20 +879,26 @@ private CodeEntry readFunction(int functionIndex, byte[] locals, byte[] resultTy } case Instructions.TRY_TABLE: { checkExceptionHandlingSupport(opcode); - final byte[] tryTableParamTypes; - final byte[] tryTableResultTypes; - readBlockType(multiResult, bulkMemoryAndRefTypes, simd, exceptions); - - if (multiResult[1] == SINGLE_RESULT_VALUE) { - tryTableParamTypes = WasmType.VOID_TYPE_ARRAY; - tryTableResultTypes = encapsulateResultType(multiResult[0]); - } else if (multiValue) { - final int typeIndex = multiResult[0]; - state.checkFunctionTypeExists(typeIndex, module.typeCount()); - tryTableParamTypes = extractBlockParamTypes(typeIndex); - tryTableResultTypes = extractBlockResultTypes(typeIndex); - } else { - throw WasmException.create(Failure.DISABLED_MULTI_VALUE); + final int[] tryTableParamTypes; + final int[] tryTableResultTypes; + readBlockType(multiResult); + // Extract value based on result arity. + switch (multiResult[1]) { + case BLOCK_TYPE_VOID -> { + tryTableParamTypes = WasmType.VOID_TYPE_ARRAY; + tryTableResultTypes = WasmType.VOID_TYPE_ARRAY; + } + case BLOCK_TYPE_VALTYPE -> { + tryTableParamTypes = WasmType.VOID_TYPE_ARRAY; + tryTableResultTypes = encapsulateResultType(multiResult[0]); + } + case BLOCK_TYPE_TYPE_INDEX -> { + int typeIndex = multiResult[0]; + checkFunctionTypeExists(typeIndex); + tryTableParamTypes = extractBlockParamTypes(typeIndex); + tryTableResultTypes = extractBlockResultTypes(typeIndex); + } + default -> throw WasmException.create(Failure.DISABLED_MULTI_VALUE); } state.popAll(tryTableParamTypes); final ExceptionHandler[] handlers = readExceptionHandlers(state); @@ -857,8 +909,8 @@ private CodeEntry readFunction(int functionIndex, byte[] locals, byte[] resultTy checkExceptionHandlingSupport(opcode); final int tagIndex = readTagIndex(); final int typeIndex = module.tagTypeIndex(tagIndex); - final byte[] paramTypes = module.typeAt(typeIndex).paramTypes(); - state.popAll(paramTypes); + final int[] tagParamTypes = module.functionTypeParamTypesAsArray(typeIndex); + state.popAll(tagParamTypes); state.addMiscFlag(); state.addInstruction(Bytecode.THROW, tagIndex); @@ -877,7 +929,8 @@ private CodeEntry readFunction(int functionIndex, byte[] locals, byte[] resultTy case Instructions.LOCAL_GET: { final int localIndex = readLocalIndex(); assertUnsignedIntLess(localIndex, locals.length, Failure.UNKNOWN_LOCAL); - final byte localType = locals[localIndex]; + Assert.assertTrue(state.isLocalInitialized(localIndex), Failure.UNINITIALIZED_LOCAL); + final int localType = locals[localIndex]; state.push(localType); if (WasmType.isNumberType(localType)) { state.addUnsignedInstruction(Bytecode.LOCAL_GET_U8, localIndex); @@ -889,7 +942,8 @@ private CodeEntry readFunction(int functionIndex, byte[] locals, byte[] resultTy case Instructions.LOCAL_SET: { final int localIndex = readLocalIndex(); assertUnsignedIntLess(localIndex, locals.length, Failure.UNKNOWN_LOCAL); - final byte localType = locals[localIndex]; + state.initializeLocal(localIndex); + final int localType = locals[localIndex]; state.popChecked(localType); if (WasmType.isNumberType(localType)) { state.addUnsignedInstruction(Bytecode.LOCAL_SET_U8, localIndex); @@ -901,7 +955,8 @@ private CodeEntry readFunction(int functionIndex, byte[] locals, byte[] resultTy case Instructions.LOCAL_TEE: { final int localIndex = readLocalIndex(); assertUnsignedIntLess(localIndex, locals.length, Failure.UNKNOWN_LOCAL); - final byte localType = locals[localIndex]; + state.initializeLocal(localIndex); + final int localType = locals[localIndex]; state.popChecked(localType); state.push(localType); if (WasmType.isNumberType(localType)) { @@ -929,18 +984,20 @@ private CodeEntry readFunction(int functionIndex, byte[] locals, byte[] resultTy case Instructions.TABLE_GET: { checkBulkMemoryAndRefTypesSupport(opcode); final int index = readTableIndex(); - final byte elementType = module.tableElementType(index); + final int elementType = module.tableElementType(index); state.popChecked(I32_TYPE); state.push(elementType); + state.addMiscFlag(); state.addInstruction(Bytecode.TABLE_GET, index); break; } case Instructions.TABLE_SET: { checkBulkMemoryAndRefTypesSupport(opcode); final int index = readTableIndex(); - final byte elementType = module.tableElementType(index); + final int elementType = module.tableElementType(index); state.popChecked(elementType); state.popChecked(I32_TYPE); + state.addMiscFlag(); state.addInstruction(Bytecode.TABLE_SET, index); break; } @@ -1076,6 +1133,54 @@ private CodeEntry readFunction(int functionIndex, byte[] locals, byte[] resultTy } break; } + case Instructions.CALL_REF: { + checkTypedFunctionReferencesSupport(opcode); + final int expectedFunctionTypeIndex = readTypeIndex(); + final int functionReferenceType = WasmType.withNullable(true, expectedFunctionTypeIndex); + state.popChecked(functionReferenceType); + // Pop parameters + final int paramCount = module.functionTypeParamCount(expectedFunctionTypeIndex); + for (int i = paramCount - 1; i >= 0; i--) { + state.popChecked(module.functionTypeParamTypeAt(expectedFunctionTypeIndex, i)); + } + // Push result values + final int resultCount = module.functionTypeResultCount(expectedFunctionTypeIndex); + if (!multiValue) { + assertIntLessOrEqual(resultCount, 1, Failure.INVALID_RESULT_ARITY); + } + for (int i = 0; i < resultCount; i++) { + state.push(module.functionTypeResultTypeAt(expectedFunctionTypeIndex, i)); + } + state.addRefCall(callNodes.size(), expectedFunctionTypeIndex); + callNodes.add(new CallNode(bytecode.location())); + break; + } + case Instructions.REF_AS_NON_NULL: { + checkTypedFunctionReferencesSupport(opcode); + final int referenceType = state.popReferenceTypeChecked(); + final int nonNullReferenceType = WasmType.withNullable(false, referenceType); + state.push(nonNullReferenceType); + state.addMiscFlag(); + state.addInstruction(Bytecode.REF_AS_NON_NULL); + break; + } + case Instructions.BR_ON_NULL: { + checkTypedFunctionReferencesSupport(opcode); + final int branchLabel = readTargetOffset(); + final int referenceType = state.popReferenceTypeChecked(); + state.addBranchOnNull(branchLabel); + final int nonNullReferenceType = WasmType.withNullable(false, referenceType); + state.push(nonNullReferenceType); + break; + } + case Instructions.BR_ON_NON_NULL: { + checkTypedFunctionReferencesSupport(opcode); + final int branchLabel = readTargetOffset(); + final int referenceType = state.popReferenceTypeChecked(); + final int nonNullReferenceType = WasmType.withNullable(false, referenceType); + state.addBranchOnNonNull(branchLabel, nonNullReferenceType); + break; + } default: readNumericInstructions(state, opcode); break; @@ -1116,14 +1221,14 @@ private CodeEntry readFunction(int functionIndex, byte[] locals, byte[] resultTy final int functionEndOffset = bytecode.location(); bytecode.addCodeEntry(functionIndex, state.maxStackSize(), bytecodeEndOffset - bytecodeStartOffset, locals.length, resultTypes.length); - for (byte local : locals) { - bytecode.addByte(local); + for (int local : locals) { + bytecode.addType(local); } if (locals.length != 0) { bytecode.addByte((byte) 0); } - for (byte result : resultTypes) { - bytecode.addByte(result); + for (int result : resultTypes) { + bytecode.addType(result); } if (resultTypes.length != 0) { bytecode.addByte((byte) 0); @@ -1530,7 +1635,7 @@ private void readNumericInstructions(ParserState state, int opcode) { final int elementIndex = readUnsignedInt32(); final int tableIndex = readTableIndex(); module.checkElemIndex(elementIndex); - final byte elementType = module.tableElementType(tableIndex); + final int elementType = module.tableElementType(tableIndex); module.checkElemType(elementIndex, elementType); state.popChecked(I32_TYPE); state.popChecked(I32_TYPE); @@ -1550,10 +1655,10 @@ private void readNumericInstructions(ParserState state, int opcode) { case Instructions.TABLE_COPY: checkBulkMemoryAndRefTypesSupport(miscOpcode); final int destinationTableIndex = readTableIndex(); - final byte destinationElementType = module.tableElementType(destinationTableIndex); + final int destinationElementType = module.tableElementType(destinationTableIndex); final int sourceTableIndex = readTableIndex(); - final byte sourceElementType = module.tableElementType(sourceTableIndex); - assertByteEqual(sourceElementType, destinationElementType, Failure.TYPE_MISMATCH); + final int sourceElementType = module.tableElementType(sourceTableIndex); + Assert.assertTrue(module.matchesType(destinationElementType, sourceElementType), Failure.TYPE_MISMATCH); state.popChecked(I32_TYPE); state.popChecked(I32_TYPE); state.popChecked(I32_TYPE); @@ -1571,7 +1676,7 @@ private void readNumericInstructions(ParserState state, int opcode) { case Instructions.TABLE_GROW: { checkBulkMemoryAndRefTypesSupport(miscOpcode); final int tableIndex = readTableIndex(); - final byte elementType = module.tableElementType(tableIndex); + final int elementType = module.tableElementType(tableIndex); state.popChecked(I32_TYPE); state.popChecked(elementType); state.push(I32_TYPE); @@ -1582,7 +1687,7 @@ private void readNumericInstructions(ParserState state, int opcode) { case Instructions.TABLE_FILL: { checkBulkMemoryAndRefTypesSupport(miscOpcode); final int tableIndex = readTableIndex(); - final byte elementType = module.tableElementType(tableIndex); + final int elementType = module.tableElementType(tableIndex); state.popChecked(I32_TYPE); state.popChecked(elementType); state.popChecked(I32_TYPE); @@ -1611,8 +1716,9 @@ private void readNumericInstructions(ParserState state, int opcode) { break; case Instructions.REF_NULL: checkBulkMemoryAndRefTypesSupport(opcode); - final byte type = readRefType(exceptions); - state.push(type); + final int heapType = readHeapType(); + final int nullableReferenceType = WasmType.withNullable(true, heapType); + state.push(nullableReferenceType); state.addInstruction(Bytecode.REF_NULL); break; case Instructions.REF_IS_NULL: @@ -1625,7 +1731,8 @@ private void readNumericInstructions(ParserState state, int opcode) { checkBulkMemoryAndRefTypesSupport(opcode); final int functionIndex = readDeclaredFunctionIndex(); module.checkFunctionReference(functionIndex); - state.push(FUNCREF_TYPE); + final int functionReferenceType = WasmType.withNullable(false, module.function(functionIndex).typeIndex()); + state.push(functionReferenceType); state.addInstruction(Bytecode.REF_FUNC, functionIndex); break; case Instructions.ATOMIC: @@ -2341,11 +2448,11 @@ private void readNumericInstructions(ParserState state, int opcode) { state.addInstruction(vectorOpcodeToBytecode(vectorOpcode)); break; default: - fail(Failure.UNSPECIFIED_MALFORMED, "Unknown opcode: 0xFD 0x%02x", vectorOpcode); + fail(Failure.ILLEGAL_OPCODE, "Unknown opcode: 0xFD 0x%02x", vectorOpcode); } break; default: - fail(Failure.UNSPECIFIED_MALFORMED, "Unknown opcode: 0x%02x", opcode); + fail(Failure.ILLEGAL_OPCODE, "Unknown opcode: 0x%02x", opcode); break; } } @@ -2384,7 +2491,11 @@ private void checkExceptionHandlingSupport(int opcode) { checkContextOption(wasmContext.getContextOptions().supportExceptions(), "Exception handling is not enabled (opcode: 0x%02x)", opcode); } - private void store(ParserState state, byte type, int n, long[] result) { + private void checkTypedFunctionReferencesSupport(int opcode) { + checkContextOption(wasmContext.getContextOptions().supportTypedFunctionReferences(), "Typed function references are not enabled (opcode: 0x%02x)", opcode); + } + + private void store(ParserState state, int type, int n, long[] result) { int alignHint = readAlignHint(n); final int memoryIndex = readMemoryIndexFromAlignHint(alignHint); final long memoryOffset = readBaseMemoryOffset(); @@ -2398,7 +2509,7 @@ private void store(ParserState state, byte type, int n, long[] result) { result[1] = memoryOffset; } - private void load(ParserState state, byte type, int n, long[] result) { + private void load(ParserState state, int type, int n, long[] result) { final int alignHint = readAlignHint(n); final int memoryIndex = readMemoryIndexFromAlignHint(alignHint); final long memoryOffset = readBaseMemoryOffset(); @@ -2412,7 +2523,7 @@ private void load(ParserState state, byte type, int n, long[] result) { result[1] = memoryOffset; } - private void atomicStore(ParserState state, byte type, int n, long[] result) { + private void atomicStore(ParserState state, int type, int n, long[] result) { int alignHint = readAtomicAlignHint(n); final int memoryIndex = readMemoryIndexFromAlignHint(alignHint); final long memoryOffset = readBaseMemoryOffset(); @@ -2426,7 +2537,7 @@ private void atomicStore(ParserState state, byte type, int n, long[] result) { result[1] = memoryOffset; } - private void atomicLoad(ParserState state, byte type, int n, long[] result) { + private void atomicLoad(ParserState state, int type, int n, long[] result) { final int alignHint = readAtomicAlignHint(n); final int memoryIndex = readMemoryIndexFromAlignHint(alignHint); final long memoryOffset = readBaseMemoryOffset(); @@ -2440,7 +2551,7 @@ private void atomicLoad(ParserState state, byte type, int n, long[] result) { result[1] = memoryOffset; } - private void atomicReadModifyWrite(ParserState state, byte type, int n, long[] result) { + private void atomicReadModifyWrite(ParserState state, int type, int n, long[] result) { final int alignHint = readAtomicAlignHint(n); final int memoryIndex = readMemoryIndexFromAlignHint(alignHint); final long memoryOffset = readBaseMemoryOffset(); @@ -2455,7 +2566,7 @@ private void atomicReadModifyWrite(ParserState state, byte type, int n, long[] r result[1] = memoryOffset; } - private void atomicCompareExchange(ParserState state, byte type, int n, long[] result) { + private void atomicCompareExchange(ParserState state, int type, int n, long[] result) { final int alignHint = readAtomicAlignHint(n); final int memoryIndex = readMemoryIndexFromAlignHint(alignHint); final long memoryOffset = readBaseMemoryOffset(); @@ -2486,7 +2597,7 @@ private void atomicNotify(ParserState state, long[] result) { result[1] = memoryOffset; } - private void atomicWait(ParserState state, byte type, int n, long[] result) { + private void atomicWait(ParserState state, int type, int n, long[] result) { final int alignHint = readAtomicAlignHint(n); final int memoryIndex = readMemoryIndexFromAlignHint(alignHint); final long memoryOffset = readBaseMemoryOffset(); @@ -2514,7 +2625,7 @@ private ExceptionHandler[] readExceptionHandlers(ParserState state) { final int label = readUnsignedInt32(); assertUnsignedIntLess(label, state.controlStackSize(), Failure.INVALID_CATCH_CLAUSE_LABEL); final int typeIndex = module.tagTypeIndex(tag); - final byte[] paramTypes = module.typeAt(typeIndex).paramTypes(); + final int[] paramTypes = module.functionTypeParamTypesAsArray(typeIndex); handlers[i] = state.enterCatchClause(opcode, tag, label); state.pushAll(paramTypes); } @@ -2523,7 +2634,7 @@ private ExceptionHandler[] readExceptionHandlers(ParserState state) { final int label = readUnsignedInt32(); assertUnsignedIntLess(label, state.controlStackSize(), Failure.INVALID_CATCH_CLAUSE_LABEL); final int typeIndex = module.tagTypeIndex(tag); - final byte[] paramTypes = module.typeAt(typeIndex).paramTypes(); + final int[] paramTypes = module.functionTypeParamTypesAsArray(typeIndex); handlers[i] = state.enterCatchClause(opcode, tag, label); state.pushAll(paramTypes); state.push(EXNREF_TYPE); @@ -2544,11 +2655,11 @@ private ExceptionHandler[] readExceptionHandlers(ParserState state) { return handlers; } - private Pair readOffsetExpression() { + private Pair readOffsetExpression(int endOffset) { // Table offset expression must be a constant expression with result type i32. // https://webassembly.github.io/spec/core/syntax/modules.html#element-segments // https://webassembly.github.io/spec/core/valid/instructions.html#constant-expressions - Pair result = readConstantExpression(I32_TYPE, true); + Pair result = readConstantExpression(I32_TYPE, endOffset); if (result.getRight() == null) { return Pair.create((int) result.getLeft(), null); } else { @@ -2556,8 +2667,8 @@ private Pair readOffsetExpression() { } } - private Pair readLongOffsetExpression() { - Pair result = readConstantExpression(I64_TYPE, true); + private Pair readLongOffsetExpression(int endOffset) { + Pair result = readConstantExpression(I64_TYPE, endOffset); if (result.getRight() == null) { return Pair.create((long) result.getLeft(), null); } else { @@ -2565,18 +2676,19 @@ private Pair readLongOffsetExpression() { } } - private Pair readConstantExpression(byte resultType, boolean onlyImportedGlobals) { + private Pair readConstantExpression(int resultType, int endOffset) { // Read the constant expression. // https://webassembly.github.io/spec/core/valid/instructions.html#constant-expressions final RuntimeBytecodeGen bytecode = new RuntimeBytecodeGen(); - final ParserState state = new ParserState(bytecode); + final ParserState state = new ParserState(bytecode, module); final List stack = new ArrayList<>(); boolean calculable = true; - state.enterFunction(new byte[]{resultType}); - int opcode; - while ((opcode = read1() & 0xFF) != Instructions.END) { + state.enterFunction(EMPTY_TYPES, new int[]{resultType}, EMPTY_TYPES); + int opcode = -1; + read_loop: while (offset < endOffset) { + opcode = read1() & 0xFF; switch (opcode) { case Instructions.I32_CONST: { final int value = readSignedInt32(); @@ -2618,8 +2730,9 @@ private Pair readConstantExpression(byte resultType, boolean onl } case Instructions.REF_NULL: checkBulkMemoryAndRefTypesSupport(opcode); - final byte type = readRefType(exceptions); - state.push(type); + final int heapType = readHeapType(); + final int nullableReferenceType = WasmType.withNullable(true, heapType); + state.push(nullableReferenceType); state.addInstruction(Bytecode.REF_NULL); if (calculable) { stack.add(WasmConstant.NULL); @@ -2629,18 +2742,13 @@ private Pair readConstantExpression(byte resultType, boolean onl checkBulkMemoryAndRefTypesSupport(opcode); final int functionIndex = readDeclaredFunctionIndex(); module.addFunctionReference(functionIndex); - state.push(FUNCREF_TYPE); + final int functionReferenceType = WasmType.withNullable(false, module.function(functionIndex).typeIndex()); + state.push(functionReferenceType); state.addInstruction(Bytecode.REF_FUNC, functionIndex); calculable = false; break; case Instructions.GLOBAL_GET: { final int index = readGlobalIndex(); - if (onlyImportedGlobals) { - // The current WebAssembly spec says constant expressions can only refer to - // imported globals. We can easily remove this restriction in the future. - assertUnsignedIntLess(index, module.symbolTable().importedGlobals().size(), Failure.UNKNOWN_GLOBAL, - "Constant expression in module '%s' refers to non-imported global %d.", module.name(), index); - } assertIntEqual(module.globalMutability(index), GlobalModifier.CONSTANT, Failure.CONSTANT_EXPRESSION_REQUIRED); state.push(module.symbolTable().globalValueType(index)); state.addUnsignedInstruction(Bytecode.GLOBAL_GET_U8, index); @@ -2708,12 +2816,15 @@ private Pair readConstantExpression(byte resultType, boolean onl break; } break; + case Instructions.END: + break read_loop; default: fail(Failure.ILLEGAL_OPCODE, "Invalid instruction for constant expression: 0x%02X", opcode); break; } } - assertIntEqual(state.valueStackSize(), 1, Failure.TYPE_MISMATCH, "Multiple results on stack at constant expression end"); + Assert.assertTrue(opcode == Instructions.END, Failure.UNEXPECTED_END); + assertIntEqual(state.valueStackSize(), 1, Failure.TYPE_MISMATCH, "Unexpected number of results on stack at constant expression end"); state.exit(multiValue); if (calculable) { return Pair.create(stack.removeLast(), null); @@ -2722,14 +2833,16 @@ private Pair readConstantExpression(byte resultType, boolean onl } } - private long[] readFunctionIndices() { + private long[] readFunctionIndices(int elemType) { final int functionIndexCount = readLength(); final long[] functionIndices = new long[functionIndexCount]; for (int index = 0; index != functionIndexCount; index++) { assertTrue(!isEOF(), Failure.LENGTH_OUT_OF_BOUNDS); final int functionIndex = readDeclaredFunctionIndex(); module.addFunctionReference(functionIndex); - functionIndices[index] = ((long) FUNCREF_TYPE << 32) | functionIndex; + final int functionReferenceType = WasmType.withNullable(false, module.function(functionIndex).typeIndex()); + Assert.assertTrue(module.matchesType(elemType, functionReferenceType), Failure.TYPE_MISMATCH); + functionIndices[index] = ((long) ELEM_ITEM_REF_FUNC_ENTRY_PREFIX << 32) | functionIndex; } return functionIndices; } @@ -2741,9 +2854,9 @@ private void checkElemKind() { } } - private long[] readElemExpressions(byte elemType) { + private long[] readElemExpressions(int elemType) { final int expressionCount = readLength(); - final long[] functionIndices = new long[expressionCount]; + final long[] elements = new long[expressionCount]; for (int index = 0; index != expressionCount; index++) { assertTrue(!isEOF(), Failure.LENGTH_OUT_OF_BOUNDS); int opcode = read1() & 0xFF; @@ -2765,26 +2878,24 @@ private long[] readElemExpressions(byte elemType) { throw WasmException.format(Failure.ILLEGAL_OPCODE, "Illegal opcode for constant expression: 0x%02X", opcode); } case Instructions.REF_NULL: - final byte type = readRefType(exceptions); - if (bulkMemoryAndRefTypes && type != elemType) { - fail(Failure.TYPE_MISMATCH, "Invalid ref.null type: 0x%02X", type); - } - functionIndices[index] = ((long) NULL_TYPE << 32); + final int heapType = readHeapType(); + final int nullableReferenceType = WasmType.withNullable(true, heapType); + Assert.assertTrue(module.matchesType(elemType, nullableReferenceType), "Invalid ref.null type: 0x%02X", Failure.TYPE_MISMATCH); + elements[index] = ((long) ELEM_ITEM_REF_NULL_ENTRY_PREFIX << 32); break; case Instructions.REF_FUNC: - if (elemType != FUNCREF_TYPE) { - fail(Failure.TYPE_MISMATCH, "Invalid element type: 0x%02X", FUNCREF_TYPE); - } final int functionIndex = readDeclaredFunctionIndex(); module.addFunctionReference(functionIndex); - functionIndices[index] = ((long) FUNCREF_TYPE << 32) | functionIndex; + final int functionReferenceType = WasmType.withNullable(false, module.function(functionIndex).typeIndex()); + Assert.assertTrue(module.matchesType(elemType, functionReferenceType), "Invalid element type: 0x%02X", Failure.TYPE_MISMATCH); + elements[index] = ((long) ELEM_ITEM_REF_FUNC_ENTRY_PREFIX << 32) | functionIndex; break; case Instructions.GLOBAL_GET: final int globalIndex = readGlobalIndex(); assertIntEqual(module.globalMutability(globalIndex), GlobalModifier.CONSTANT, Failure.CONSTANT_EXPRESSION_REQUIRED); - final byte valueType = module.globalValueType(globalIndex); - assertByteEqual(valueType, elemType, Failure.TYPE_MISMATCH); - functionIndices[index] = ((long) I32_TYPE << 32) | globalIndex; + final int valueType = module.globalValueType(globalIndex); + Assert.assertTrue(module.matchesType(elemType, valueType), Failure.TYPE_MISMATCH); + elements[index] = ((long) ELEM_ITEM_GLOBAL_GET_ENTRY_PREFIX << 32) | globalIndex; break; case Instructions.VECTOR: checkSIMDSupport(); @@ -2800,10 +2911,10 @@ private long[] readElemExpressions(byte elemType) { } readEnd(); } - return functionIndices; + return elements; } - private void readElementSection(RuntimeBytecodeGen bytecode) { + private void readElementSection(RuntimeBytecodeGen bytecode, int endOffset) { int elemSegmentCount = readLength(); module.limits().checkElementSegmentCount(elemSegmentCount); for (int elemSegmentIndex = 0; elemSegmentIndex != elemSegmentCount; elemSegmentIndex++) { @@ -2813,7 +2924,7 @@ private void readElementSection(RuntimeBytecodeGen bytecode) { final byte[] currentOffsetBytecode; final long[] elements; final int tableIndex; - final byte elemType; + final int elemType; if (bulkMemoryAndRefTypes) { final int sectionType = readUnsignedInt32(); mode = sectionType & 0b001; @@ -2826,7 +2937,7 @@ private void readElementSection(RuntimeBytecodeGen bytecode) { } else { tableIndex = 0; } - Pair offsetExpression = readOffsetExpression(); + Pair offsetExpression = readOffsetExpression(endOffset); currentOffsetAddress = offsetExpression.getLeft(); currentOffsetBytecode = offsetExpression.getRight(); } else { @@ -2837,7 +2948,7 @@ private void readElementSection(RuntimeBytecodeGen bytecode) { } if (useExpressions) { if (useType) { - elemType = readRefType(exceptions); + elemType = readRefType(); } else { elemType = FUNCREF_TYPE; } @@ -2845,18 +2956,20 @@ private void readElementSection(RuntimeBytecodeGen bytecode) { } else { if (useType) { checkElemKind(); + elemType = FUNCREF_TYPE; + } else { + elemType = WasmType.withNullable(false, FUNC_HEAPTYPE); } - elemType = FUNCREF_TYPE; - elements = readFunctionIndices(); + elements = readFunctionIndices(elemType); } } else { mode = SegmentMode.ACTIVE; tableIndex = readTableIndex(); - Pair offsetExpression = readOffsetExpression(); + Pair offsetExpression = readOffsetExpression(endOffset); currentOffsetAddress = offsetExpression.getLeft(); currentOffsetBytecode = offsetExpression.getRight(); - elements = readFunctionIndices(); elemType = FUNCREF_TYPE; + elements = readFunctionIndices(elemType); } // Copy the contents, or schedule a linker task for this. @@ -2880,14 +2993,14 @@ private void readElementSection(RuntimeBytecodeGen bytecode) { for (long element : elements) { final int initType = (int) (element >> 32); switch (initType) { - case NULL_TYPE: + case ELEM_ITEM_REF_NULL_ENTRY_PREFIX: bytecode.addElemNull(); break; - case FUNCREF_TYPE: + case ELEM_ITEM_REF_FUNC_ENTRY_PREFIX: final int functionIndex = (int) element; bytecode.addElemFunctionIndex(functionIndex); break; - case I32_TYPE: + case ELEM_ITEM_GLOBAL_GET_ENTRY_PREFIX: final int globalIndex = (int) element; bytecode.addElemGlobalIndex(globalIndex); break; @@ -2965,18 +3078,18 @@ private void readTagSection() { } } - private void readGlobalSection() { + private void readGlobalSection(int endOffset) { final int globalCount = readLength(); module.limits().checkGlobalCount(globalCount); final int startingGlobalIndex = module.symbolTable().numGlobals(); for (int globalIndex = startingGlobalIndex; globalIndex != startingGlobalIndex + globalCount; globalIndex++) { assertTrue(!isEOF(), Failure.LENGTH_OUT_OF_BOUNDS); - final byte type = readValueType(bulkMemoryAndRefTypes, simd, exceptions); + final int type = readValueType(); // 0x00 means const, 0x01 means var final byte mutability = readMutability(); // Global initialization expressions must be constant expressions: // https://webassembly.github.io/spec/core/valid/instructions.html#constant-expressions - Pair initExpression = readConstantExpression(type, true); + Pair initExpression = readConstantExpression(type, endOffset); final Object initValue = initExpression.getLeft(); final byte[] initBytecode = initExpression.getRight(); final boolean isInitialized = initBytecode == null; @@ -2993,7 +3106,7 @@ private void readDataCountSection(int size) { } } - private void readDataSection(RuntimeBytecodeGen bytecode) { + private void readDataSection(RuntimeBytecodeGen bytecode, int endOffset) { final int dataSegmentCount = readLength(); module.limits().checkDataSegmentCount(dataSegmentCount); if (bulkMemoryAndRefTypes) { @@ -3023,11 +3136,11 @@ private void readDataSection(RuntimeBytecodeGen bytecode) { if (mode == SegmentMode.ACTIVE) { checkMemoryIndex(memoryIndex); if (module.memoryHasIndexType64(memoryIndex)) { - Pair offsetExpression = readLongOffsetExpression(); + Pair offsetExpression = readLongOffsetExpression(endOffset); offsetAddress = offsetExpression.getLeft(); offsetBytecode = offsetExpression.getRight(); } else { - Pair offsetExpression = readOffsetExpression(); + Pair offsetExpression = readOffsetExpression(endOffset); offsetAddress = offsetExpression.getLeft(); offsetBytecode = offsetExpression.getRight(); } @@ -3044,11 +3157,11 @@ private void readDataSection(RuntimeBytecodeGen bytecode) { memoryIndex = 0; } if (module.memoryHasIndexType64(memoryIndex)) { - Pair offsetExpression = readLongOffsetExpression(); + Pair offsetExpression = readLongOffsetExpression(endOffset); offsetAddress = offsetExpression.getLeft(); offsetBytecode = offsetExpression.getRight(); } else { - Pair offsetExpression = readOffsetExpression(); + Pair offsetExpression = readOffsetExpression(endOffset); offsetAddress = offsetExpression.getLeft(); offsetBytecode = offsetExpression.getRight(); } @@ -3086,29 +3199,101 @@ private void readDataSection(RuntimeBytecodeGen bytecode) { private void readFunctionType() { int paramCount = readLength(); - long resultCountAndValue = peekUnsignedInt32AndLength(data, offset + paramCount); - int resultCount = value(resultCountAndValue); - resultCount = (resultCount == 0x40) ? 0 : resultCount; - module.limits().checkParamCount(paramCount); + int[] paramTypes = new int[paramCount]; + for (int paramIdx = 0; paramIdx < paramCount; paramIdx++) { + paramTypes[paramIdx] = readValueType(); + } + + int resultCount = readLength(); module.limits().checkResultCount(resultCount, multiValue); - int idx = module.symbolTable().allocateFunctionType(paramCount, resultCount, multiValue); - readParameterList(idx, paramCount); - offset += length(resultCountAndValue); - readResultList(idx, resultCount); - } + int[] resultTypes = new int[resultCount]; + for (int resultIdx = 0; resultIdx < resultCount; resultIdx++) { + resultTypes[resultIdx] = readValueType(); + } - private void readParameterList(int funcTypeIdx, int paramCount) { - for (int paramIdx = 0; paramIdx != paramCount; ++paramIdx) { - byte type = readValueType(bulkMemoryAndRefTypes, simd, exceptions); - module.symbolTable().registerFunctionTypeParameterType(funcTypeIdx, paramIdx, type); + int funcTypeIdx = module.symbolTable().allocateFunctionType(paramCount, resultCount, multiValue); + for (int paramIdx = 0; paramIdx < paramCount; paramIdx++) { + module.symbolTable().registerFunctionTypeParameterType(funcTypeIdx, paramIdx, paramTypes[paramIdx]); + } + for (int resultIdx = 0; resultIdx < resultCount; resultIdx++) { + module.symbolTable().registerFunctionTypeResultType(funcTypeIdx, resultIdx, resultTypes[resultIdx]); } + module.symbolTable().finishFunctionType(funcTypeIdx); + } + + protected int readValueType() { + final int type = readSignedInt32(); + return switch (type) { + case I32_TYPE, I64_TYPE, F32_TYPE, F64_TYPE -> type; + case V128_TYPE -> { + Assert.assertTrue(simd, Failure.MALFORMED_VALUE_TYPE); + yield type; + } + case FUNCREF_TYPE, EXTERNREF_TYPE -> { + Assert.assertTrue(bulkMemoryAndRefTypes, Failure.MALFORMED_VALUE_TYPE); + yield type; + } + case EXNREF_TYPE -> { + Assert.assertTrue(exceptions, Failure.MALFORMED_VALUE_TYPE); + yield type; + } + case REF_NULL_TYPE_HEADER -> { + Assert.assertTrue(typedFunctionReferences, Failure.MALFORMED_VALUE_TYPE); + yield WasmType.withNullable(true, readHeapType()); + } + case REF_TYPE_HEADER -> { + Assert.assertTrue(typedFunctionReferences, Failure.MALFORMED_VALUE_TYPE); + yield WasmType.withNullable(false, readHeapType()); + } + default -> throw Assert.fail(Failure.MALFORMED_VALUE_TYPE, "Invalid value type: 0x%02X", type); + }; } - private void readResultList(int funcTypeIdx, int resultCount) { - for (int resultIdx = 0; resultIdx != resultCount; resultIdx++) { - byte type = readValueType(bulkMemoryAndRefTypes, simd, exceptions); - module.symbolTable().registerFunctionTypeResultType(funcTypeIdx, resultIdx, type); + /** + * Reads the block type at the current location. The result is provided as two values. The first + * is the actual value of the block type. The second is an indicator if it is a single result + * type or a multi-value result. + * + * @param result The array used for returning the result. + * + */ + protected void readBlockType(int[] result) { + int type = readSignedInt32(); + switch (type) { + case VOID_BLOCK_TYPE -> { + result[1] = BLOCK_TYPE_VOID; + } + case I32_TYPE, I64_TYPE, F32_TYPE, F64_TYPE -> { + result[0] = type; + result[1] = BLOCK_TYPE_VALTYPE; + } + case V128_TYPE -> { + Assert.assertTrue(simd, Failure.MALFORMED_VALUE_TYPE); + result[0] = type; + result[1] = BLOCK_TYPE_VALTYPE; + } + case FUNCREF_TYPE, EXTERNREF_TYPE -> { + Assert.assertTrue(bulkMemoryAndRefTypes, Failure.MALFORMED_VALUE_TYPE); + result[0] = type; + result[1] = BLOCK_TYPE_VALTYPE; + } + case EXNREF_TYPE -> { + Assert.assertTrue(exceptions, Failure.MALFORMED_VALUE_TYPE); + result[0] = type; + result[1] = BLOCK_TYPE_VALTYPE; + } + case REF_NULL_TYPE_HEADER, REF_TYPE_HEADER -> { + boolean nullable = type == REF_NULL_TYPE_HEADER; + int heapType = readHeapType(); + result[0] = WasmType.withNullable(nullable, heapType); + result[1] = BLOCK_TYPE_VALTYPE; + } + default -> { + result[0] = type; + Assert.assertIntGreaterOrEqual(result[0], 0, Failure.MALFORMED_VALUE_TYPE); + result[1] = BLOCK_TYPE_TYPE_INDEX; + } } } @@ -3145,10 +3330,26 @@ private int readDeclaredFunctionIndex() { return index; } + /** + * Checks if the given function type is within range. + * + * @param typeIndex The function type. + * @throws WasmException If the given function type is greater or equal to the given maximum. + */ + public void checkFunctionTypeExists(int typeIndex) { + if (compareUnsigned(typeIndex, module.typeCount()) >= 0) { + if (module.typeCount() > 0) { + throw ValidationErrors.createMissingFunctionType(typeIndex, module.typeCount() - 1); + } else { + throw ValidationErrors.createMissingFunctionType(typeIndex); + } + } + } + private int readTypeIndex() { - final int result = readUnsignedInt32(); - assertUnsignedIntLess(result, module.symbolTable().typeCount(), Failure.UNKNOWN_TYPE); - return result; + final int typeIndex = readUnsignedInt32(); + checkFunctionTypeExists(typeIndex); + return typeIndex; } private int readFunctionIndex() { @@ -3198,20 +3399,42 @@ private byte readImportType() { return read1(); } - private byte readRefType(boolean allowExnType) { - final byte refType = read1(); - switch (refType) { - case FUNCREF_TYPE: - case EXTERNREF_TYPE: - break; - case EXNREF_TYPE: - assertTrue(allowExnType, Failure.MALFORMED_REFERENCE_TYPE); - break; - default: - fail(Failure.MALFORMED_REFERENCE_TYPE, "Unexpected reference type"); - break; - } - return refType; + private int readRefType() { + final int refType = readSignedInt32(); + return switch (refType) { + case FUNCREF_TYPE, EXTERNREF_TYPE -> refType; + case EXNREF_TYPE -> { + assertTrue(exceptions, Failure.MALFORMED_REFERENCE_TYPE); + yield refType; + } + case REF_NULL_TYPE_HEADER -> { + assertTrue(typedFunctionReferences, Failure.MALFORMED_REFERENCE_TYPE); + yield WasmType.withNullable(true, readHeapType()); + } + case REF_TYPE_HEADER -> { + assertTrue(typedFunctionReferences, Failure.MALFORMED_REFERENCE_TYPE); + yield WasmType.withNullable(false, readHeapType()); + } + default -> throw fail(Failure.MALFORMED_REFERENCE_TYPE, "Unexpected reference type"); + }; + } + + private int readHeapType() { + int heapType = readSignedInt32(); + return switch (heapType) { + case FUNC_HEAPTYPE, EXTERN_HEAPTYPE -> heapType; + case EXN_HEAPTYPE -> { + assertTrue(exceptions, Failure.MALFORMED_HEAP_TYPE); + yield heapType; + } + default -> { + if (heapType < 0 || heapType >= module.typeCount()) { + throw fail(Failure.UNKNOWN_TYPE, "Unknown heap type %d", heapType); + } + assertTrue(typedFunctionReferences, Failure.MALFORMED_HEAP_TYPE); + yield heapType; + } + }; } private void readTableLimits(int[] out) { @@ -3247,11 +3470,7 @@ private void readLimits(int[] out, int max) { break; } default: - if (limitsPrefix < 0) { - fail(Failure.INTEGER_REPRESENTATION_TOO_LONG, "Invalid limits prefix (expected 0x00 or 0x01, got 0x%02X)", limitsPrefix); - } else { - fail(Failure.INTEGER_TOO_LARGE, "Invalid limits prefix (expected 0x00 or 0x01, got 0x%02X)", limitsPrefix); - } + fail(Failure.MALFORMED_LIMITS_FLAGS, "Invalid limits prefix (expected 0x00 or 0x01, got 0x%02X)", limitsPrefix); } } @@ -3288,11 +3507,7 @@ private void readLongLimits(long[] longOut, boolean[] boolOut, int max32Bit, lon } default: { if (!threads) { - if (limitsPrefix < 0) { - fail(Failure.INTEGER_REPRESENTATION_TOO_LONG, "Invalid limits prefix (expected 0x00, 0x01, 0x04, or 0x05, got 0x%02X)", limitsPrefix); - } else { - fail(Failure.INTEGER_TOO_LARGE, "Invalid limits prefix (expected 0x00, 0x01, 0x04, or 0x05, got 0x%02X)", limitsPrefix); - } + fail(Failure.MALFORMED_LIMITS_FLAGS, "Invalid limits prefix (expected 0x00, 0x01, 0x04, or 0x05, got 0x%02X)", limitsPrefix); } else { switch (limitsPrefix) { case 0x02: @@ -3315,11 +3530,7 @@ private void readLongLimits(long[] longOut, boolean[] boolOut, int max32Bit, lon break; } default: - if (limitsPrefix < 0) { - fail(Failure.INTEGER_REPRESENTATION_TOO_LONG, "Invalid limits prefix (expected 0x00-0x07, got 0x%02X)", limitsPrefix); - } else { - fail(Failure.INTEGER_TOO_LARGE, "Invalid limits prefix (expected 0x00-0x07, got 0x%02X)", limitsPrefix); - } + fail(Failure.MALFORMED_LIMITS_FLAGS, "Invalid limits prefix (expected 0x00-0x07, got 0x%02X)", limitsPrefix); } } } @@ -3422,7 +3633,7 @@ public Pair createFunctionDebugBytecode(int functionIndex, Ec final CodeEntry codeEntry = BytecodeParser.readCodeEntry(module, module.bytecode(), codeEntryIndex); offset = module.functionSourceCodeInstructionOffset(functionIndex); final int endOffset = module.functionSourceCodeEndOffset(functionIndex); - final CodeEntry result = readFunction(functionIndex, codeEntry.localTypes(), codeEntry.resultTypes(), endOffset, true, bytecode, codeEntryIndex, offsetToLineIndexMap); + final CodeEntry result = readFunction(functionIndex, codeEntry.localTypes(), endOffset, true, bytecode, codeEntryIndex, offsetToLineIndexMap); return Pair.create(result, bytecode.toArray()); } } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryStreamParser.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryStreamParser.java index 5ba793ba08e2..4ca161b0c6ea 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryStreamParser.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryStreamParser.java @@ -57,8 +57,9 @@ import com.oracle.truffle.api.nodes.ExplodeLoop; public abstract class BinaryStreamParser { - protected static final int SINGLE_RESULT_VALUE = 0; - protected static final int MULTI_RESULT_VALUE = 1; + protected static final int BLOCK_TYPE_VOID = 0; + protected static final int BLOCK_TYPE_VALTYPE = 1; + protected static final int BLOCK_TYPE_TYPE_INDEX = 2; private static final VarHandle I16LE = MethodHandles.byteArrayViewVarHandle(short[].class, ByteOrder.LITTLE_ENDIAN); private static final VarHandle I32LE = MethodHandles.byteArrayViewVarHandle(int[].class, ByteOrder.LITTLE_ENDIAN); @@ -294,84 +295,6 @@ protected int offset() { return offset; } - /** - * Reads the block type at the current location. The result is provided as two values. The first - * is the actual value of the block type. The second is an indicator if it is a single result - * type or a multi-value result. - * - * @param result The array used for returning the result. - * - */ - protected void readBlockType(int[] result, boolean allowRefTypes, boolean allowVecType, boolean allowExnType) { - byte type = peek1(data, offset); - switch (type) { - case WasmType.VOID_TYPE: - case WasmType.I32_TYPE: - case WasmType.I64_TYPE: - case WasmType.F32_TYPE: - case WasmType.F64_TYPE: - offset++; - result[0] = type; - result[1] = SINGLE_RESULT_VALUE; - break; - case WasmType.V128_TYPE: - Assert.assertTrue(allowVecType, Failure.MALFORMED_VALUE_TYPE); - offset++; - result[0] = type; - result[1] = SINGLE_RESULT_VALUE; - break; - case WasmType.FUNCREF_TYPE: - case WasmType.EXTERNREF_TYPE: - Assert.assertTrue(allowRefTypes, Failure.MALFORMED_VALUE_TYPE); - offset++; - result[0] = type; - result[1] = SINGLE_RESULT_VALUE; - break; - case WasmType.EXNREF_TYPE: - Assert.assertTrue(allowExnType, Failure.MALFORMED_VALUE_TYPE); - offset++; - result[0] = type; - result[1] = SINGLE_RESULT_VALUE; - break; - default: - long valueAndLength = peekSignedInt32AndLength(data, offset); - result[0] = value(valueAndLength); - Assert.assertIntGreaterOrEqual(result[0], 0, Failure.UNSPECIFIED_MALFORMED); - result[1] = MULTI_RESULT_VALUE; - offset += length(valueAndLength); - } - } - - protected static byte peekValueType(byte[] data, int offset, boolean allowRefTypes, boolean allowVecType, boolean allowExnType) { - byte b = peek1(data, offset); - switch (b) { - case WasmType.I32_TYPE: - case WasmType.I64_TYPE: - case WasmType.F32_TYPE: - case WasmType.F64_TYPE: - break; - case WasmType.V128_TYPE: - Assert.assertTrue(allowVecType, Failure.MALFORMED_VALUE_TYPE); - break; - case WasmType.FUNCREF_TYPE: - case WasmType.EXTERNREF_TYPE: - Assert.assertTrue(allowRefTypes, Failure.MALFORMED_VALUE_TYPE); - break; - case WasmType.EXNREF_TYPE: - Assert.assertTrue(allowExnType, Failure.MALFORMED_VALUE_TYPE); - break; - default: - Assert.fail(Failure.MALFORMED_VALUE_TYPE, "Invalid value type: 0x%02X", b); - } - return b; - } - - protected byte readValueType(boolean allowRefTypes, boolean allowVecType, boolean allowExnType) { - byte b = peekValueType(data, offset, allowRefTypes, allowVecType, allowExnType); - offset++; - return b; - } - @ExplodeLoop(kind = FULL_EXPLODE_UNTIL_RETURN) public static byte peekLeb128Length(byte[] data, int initialOffset) { int currentOffset = initialOffset; diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/GlobalRegistry.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/GlobalRegistry.java index 47d8a54b2cfd..8bfc6fb9b07a 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/GlobalRegistry.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/GlobalRegistry.java @@ -46,7 +46,6 @@ import org.graalvm.wasm.api.Vector128; import org.graalvm.wasm.globals.WasmGlobal; -import com.oracle.truffle.api.CompilerDirectives; import com.oracle.truffle.api.CompilerDirectives.CompilationFinal; /** @@ -117,15 +116,17 @@ private Object loadAsObject(int address) { return objectGlobals[address]; } - public void store(byte globalValueType, int address, Object value) { + public void store(int globalValueType, int address, Object value) { switch (globalValueType) { case WasmType.I32_TYPE -> storeInt(address, (int) value); case WasmType.I64_TYPE -> storeLong(address, (long) value); case WasmType.F32_TYPE -> storeFloat(address, (float) value); case WasmType.F64_TYPE -> storeDouble(address, (double) value); case WasmType.V128_TYPE -> storeVector128(address, (Vector128) value); - case WasmType.FUNCREF_TYPE, WasmType.EXTERNREF_TYPE, WasmType.EXNREF_TYPE -> storeReference(address, value); - default -> throw CompilerDirectives.shouldNotReachHere(); + default -> { + assert WasmType.isReferenceType(globalValueType); + storeReference(address, value); + } } } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java index 8479286032aa..fdd9aeb35378 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java @@ -40,8 +40,6 @@ */ package org.graalvm.wasm; -import static org.graalvm.wasm.Assert.assertByteEqual; -import static org.graalvm.wasm.Assert.assertFunctionTypeEquals; import static org.graalvm.wasm.Assert.assertTrue; import static org.graalvm.wasm.Assert.assertUnsignedIntGreaterOrEqual; import static org.graalvm.wasm.Assert.assertUnsignedIntLess; @@ -55,10 +53,8 @@ import static org.graalvm.wasm.BinaryStreamParser.rawPeekI8; import static org.graalvm.wasm.BinaryStreamParser.rawPeekU8; import static org.graalvm.wasm.Linker.ResolutionDag.NO_RESOLVE_ACTION; -import static org.graalvm.wasm.WasmType.EXTERNREF_TYPE; import static org.graalvm.wasm.WasmType.F32_TYPE; import static org.graalvm.wasm.WasmType.F64_TYPE; -import static org.graalvm.wasm.WasmType.FUNCREF_TYPE; import static org.graalvm.wasm.WasmType.I32_TYPE; import static org.graalvm.wasm.WasmType.I64_TYPE; import static org.graalvm.wasm.WasmType.V128_TYPE; @@ -87,11 +83,10 @@ import org.graalvm.wasm.Linker.ResolutionDag.ImportTableSym; import org.graalvm.wasm.Linker.ResolutionDag.ImportTagSym; import org.graalvm.wasm.Linker.ResolutionDag.InitializeGlobalSym; +import org.graalvm.wasm.Linker.ResolutionDag.InitializeTableSym; import org.graalvm.wasm.Linker.ResolutionDag.Resolver; import org.graalvm.wasm.Linker.ResolutionDag.Sym; -import org.graalvm.wasm.SymbolTable.FunctionType; import org.graalvm.wasm.api.ExecuteHostFunctionNode; -import org.graalvm.wasm.api.ValueType; import org.graalvm.wasm.api.Vector128; import org.graalvm.wasm.constants.Bytecode; import org.graalvm.wasm.constants.BytecodeBitEncoding; @@ -262,7 +257,7 @@ private static void assignTypeEquivalenceClasses(WasmModule module, WasmLanguage } final SymbolTable symtab = module.symbolTable(); for (int index = 0; index < symtab.typeCount(); index++) { - FunctionType type = symtab.typeAt(index); + SymbolTable.ClosedFunctionType type = symtab.closedFunctionTypeAt(index); int equivalenceClass = language.equivalenceClassFor(type); symtab.setEquivalenceClass(index, equivalenceClass); } @@ -332,20 +327,14 @@ private static void checkFailures(ArrayList failures) { } } - void resolveGlobalImport(WasmStore store, WasmInstance instance, ImportDescriptor importDescriptor, int globalIndex, byte valueType, byte mutability, - ImportValueSupplier imports) { + void resolveGlobalImport(WasmStore store, WasmInstance instance, ImportDescriptor importDescriptor, int globalIndex, int valueType, byte mutability, ImportValueSupplier imports) { instance.globals().setInitialized(globalIndex, false); final String importedGlobalName = importDescriptor.memberName(); final String importedModuleName = importDescriptor.moduleName(); final Runnable resolveAction = () -> { assert instance.module().globalImported(globalIndex) && globalIndex == importDescriptor.targetIndex() : importDescriptor; WasmGlobal externalGlobal = lookupImportObject(instance, importDescriptor, imports, WasmGlobal.class); - final byte exportedValueType; - final byte exportedMutability; - if (externalGlobal != null) { - exportedValueType = externalGlobal.getValueType().byteValue(); - exportedMutability = externalGlobal.getMutability(); - } else { + if (externalGlobal == null) { final WasmInstance importedInstance = store.lookupModuleInstance(importedModuleName); if (importedInstance == null) { throw WasmException.create(Failure.UNKNOWN_IMPORT, "Module '" + importedModuleName + "', referenced in the import of global variable '" + @@ -359,20 +348,22 @@ void resolveGlobalImport(WasmStore store, WasmInstance instance, ImportDescripto "', was not exported in the module '" + importedModuleName + "'."); } - exportedValueType = importedInstance.symbolTable().globalValueType(exportedGlobalIndex); - exportedMutability = importedInstance.symbolTable().globalMutability(exportedGlobalIndex); - externalGlobal = importedInstance.externalGlobal(exportedGlobalIndex); } - if (exportedValueType != valueType) { + SymbolTable.ClosedValueType importType = instance.symbolTable().closedTypeOf(valueType); + SymbolTable.ClosedValueType exportType = externalGlobal.getClosedType(); + if (mutability != externalGlobal.getMutability()) { throw WasmException.create(Failure.INCOMPATIBLE_IMPORT_TYPE, "Global variable '" + importedGlobalName + "' is imported into module '" + instance.name() + - "' with the type " + WasmType.toString(valueType) + ", " + - "'but it was exported in the module '" + importedModuleName + "' with the type " + WasmType.toString(exportedValueType) + "."); + "' with the modifier " + GlobalModifier.asString(mutability) + ", " + + "but it was exported in the module '" + importedModuleName + "' with the modifier " + GlobalModifier.asString(externalGlobal.getMutability()) + "."); } - if (exportedMutability != mutability) { + // matching for mutable globals does not work by subtyping, but requires equivalent + // types + if (!(externalGlobal.isMutable() ? importType.equals(exportType) : importType.isSupertypeOf(exportType))) { throw WasmException.create(Failure.INCOMPATIBLE_IMPORT_TYPE, "Global variable '" + importedGlobalName + "' is imported into module '" + instance.name() + - "' with the modifier " + GlobalModifier.asString(mutability) + ", " + - "'but it was exported in the module '" + importedModuleName + "' with the modifier " + GlobalModifier.asString(exportedMutability) + "."); + "' with the type " + GlobalModifier.asString(mutability) + " " + WasmType.toString(valueType) + ", " + + "but it was exported in the module '" + importedModuleName + "' with the type " + GlobalModifier.asString(externalGlobal.getMutability()) + " " + + WasmType.toString(externalGlobal.getType()) + "."); } instance.setExternalGlobal(globalIndex, externalGlobal); instance.globals().setInitialized(globalIndex, true); @@ -393,7 +384,7 @@ private static void initializeGlobal(WasmInstance instance, int globalIndex, Obj assert !instance.globals().isInitialized(globalIndex) : globalIndex; SymbolTable symbolTable = instance.symbolTable(); if (symbolTable.globalExternal(globalIndex)) { - var global = new WasmGlobal(ValueType.fromByteValue(symbolTable.globalValueType(globalIndex)), symbolTable.isGlobalMutable(globalIndex), initValue); + var global = new WasmGlobal(globalIndex, symbolTable, initValue); instance.setExternalGlobal(globalIndex, global); } else { instance.globals().store(symbolTable.globalValueType(globalIndex), symbolTable.globalAddress(globalIndex), initValue); @@ -433,7 +424,7 @@ void resolveFunctionImport(WasmStore store, WasmInstance instance, WasmFunction Object externalFunctionInstance = lookupImportObject(instance, importDescriptor, imports, Object.class); if (externalFunctionInstance != null) { if (externalFunctionInstance instanceof WasmFunctionInstance functionInstance) { - if (!function.type().equals(functionInstance.function().type())) { + if (!function.closedType().isSupertypeOf(functionInstance.function().closedType())) { throw WasmException.create(Failure.INCOMPATIBLE_IMPORT_TYPE); } instance.setTarget(function.index(), functionInstance.target()); @@ -466,7 +457,7 @@ void resolveFunctionImport(WasmStore store, WasmInstance instance, WasmFunction throw WasmException.create(Failure.UNKNOWN_IMPORT, "The imported function '" + function.importedFunctionName() + "', referenced in the module '" + instance.name() + "', does not exist in the imported module '" + function.importedModuleName() + "'."); } - if (!function.type().equals(importedFunction.type())) { + if (!function.closedType().isSupertypeOf(importedFunction.closedType())) { throw WasmException.create(Failure.INCOMPATIBLE_IMPORT_TYPE); } final CallTarget target = importedInstance.target(importedFunction.index()); @@ -542,7 +533,7 @@ void resolveMemoryExport(WasmInstance instance, int memoryIndex, String exported }); } - void resolveTagImport(WasmStore store, WasmInstance instance, ImportDescriptor importDescriptor, int tagIndex, SymbolTable.FunctionType type, ImportValueSupplier imports) { + void resolveTagImport(WasmStore store, WasmInstance instance, ImportDescriptor importDescriptor, int tagIndex, SymbolTable.ClosedFunctionType type, ImportValueSupplier imports) { final String importedModuleName = importDescriptor.moduleName(); final String importedTagName = importDescriptor.memberName(); final Runnable resolveAction = () -> { @@ -570,7 +561,8 @@ void resolveTagImport(WasmStore store, WasmInstance instance, ImportDescriptor i } importedTag = importedInstance.tag(exportedTagIndex); } - assertFunctionTypeEquals(type, importedTag.type(), Failure.INCOMPATIBLE_IMPORT_TYPE); + // matching for tag types does not work by subtyping, but requires equivalent types + Assert.assertTrue(type.equals(importedTag.type()), Failure.INCOMPATIBLE_IMPORT_TYPE); instance.setTag(tagIndex, importedTag); }; resolutionDag.resolveLater(new ImportTagSym(instance.name(), importDescriptor, tagIndex), new Sym[]{new ExportTagSym(importedModuleName, importedTagName)}, resolveAction); @@ -586,7 +578,7 @@ void resolveTagExport(WasmInstance instance, int tagIndex, String exportedTagNam private static Object lookupGlobal(WasmInstance instance, int index) { final SymbolTable symbolTable = instance.symbolTable(); - final byte type = symbolTable.globalValueType(index); + final int type = symbolTable.globalValueType(index); final int globalAddress = symbolTable.globalAddress(index); final GlobalRegistry globals = instance.globals(); if (!globals.isInitialized(index)) { @@ -598,8 +590,10 @@ private static Object lookupGlobal(WasmInstance instance, int index) { case I64_TYPE -> globals.loadAsLong(globalAddress); case F64_TYPE -> globals.loadAsDouble(globalAddress); case V128_TYPE -> globals.loadAsVector128(globalAddress); - case FUNCREF_TYPE, EXTERNREF_TYPE -> globals.loadAsReference(globalAddress); - default -> throw WasmException.create(Failure.UNSPECIFIED_TRAP, "Global variable cannot have the void type."); + default -> { + assert WasmType.isReferenceType(type); + yield globals.loadAsReference(globalAddress); + } }; } @@ -820,7 +814,30 @@ void resolvePassiveDataSegment(WasmStore store, WasmInstance instance, int dataS resolutionDag.resolveLater(new DataSym(instance.name(), dataSegmentId), dependencies.toArray(new Sym[0]), resolveAction); } - void resolveTableImport(WasmStore store, WasmInstance instance, ImportDescriptor importDescriptor, int tableIndex, int declaredMinSize, int declaredMaxSize, byte elemType, + public static void initializeTable(WasmInstance instance, int tableIndex, Object initValue) { + int tableAddress = instance.tableAddress(tableIndex); + WasmTable table = instance.store().tables().table(tableAddress); + table.fill(0, table.size(), initValue); + } + + void resolveTableInitialization(WasmInstance instance, int tableIndex, byte[] initBytecode, Object initValue) { + final Runnable resolveAction; + final Sym[] dependencies; + if (initValue != null) { + initializeTable(instance, tableIndex, initValue); + resolveAction = NO_RESOLVE_ACTION; + dependencies = ResolutionDag.NO_DEPENDENCIES; + } else if (initBytecode != null) { + resolveAction = () -> initializeTable(instance, tableIndex, evalConstantExpression(instance, initBytecode)); + dependencies = dependenciesOfConstantExpression(instance, initBytecode).toArray(ResolutionDag.NO_DEPENDENCIES); + } else { + resolveAction = NO_RESOLVE_ACTION; + dependencies = ResolutionDag.NO_DEPENDENCIES; + } + resolutionDag.resolveLater(new InitializeTableSym(instance.name(), tableIndex), dependencies, resolveAction); + } + + void resolveTableImport(WasmStore store, WasmInstance instance, ImportDescriptor importDescriptor, int tableIndex, int declaredMinSize, int declaredMaxSize, int elemType, ImportValueSupplier imports) { final Runnable resolveAction = () -> { WasmTable externalTable = lookupImportObject(instance, importDescriptor, imports, WasmTable.class); @@ -856,16 +873,19 @@ void resolveTableImport(WasmStore store, WasmInstance instance, ImportDescriptor // MAX_TABLE_DECLARATION_SIZE, so this condition will pass. assertUnsignedIntLessOrEqual(declaredMinSize, importedTable.minSize(), Failure.INCOMPATIBLE_IMPORT_TYPE); assertUnsignedIntGreaterOrEqual(declaredMaxSize, importedTable.declaredMaxSize(), Failure.INCOMPATIBLE_IMPORT_TYPE); - assertByteEqual(elemType, importedTable.elemType(), Failure.INCOMPATIBLE_IMPORT_TYPE); + // when matching element types of imported tables, we need to check for type equivalence + // instead of subtyping, as tables have read/write access + assertTrue(instance.symbolTable().closedTypeOf(elemType).equals(importedTable.closedElemType()), Failure.INCOMPATIBLE_IMPORT_TYPE); instance.setTableAddress(tableIndex, tableAddress); }; + final ImportTableSym importTableSym = new ImportTableSym(instance.name(), importDescriptor); Sym[] dependencies = new Sym[]{new ExportTableSym(importDescriptor.moduleName(), importDescriptor.memberName())}; - resolutionDag.resolveLater(new ImportTableSym(instance.name(), importDescriptor), dependencies, resolveAction); + resolutionDag.resolveLater(importTableSym, dependencies, resolveAction); + resolutionDag.resolveLater(new InitializeTableSym(instance.name(), tableIndex), new Sym[]{importTableSym}, NO_RESOLVE_ACTION); } void resolveTableExport(WasmModule module, int tableIndex, String exportedTableName) { - final ImportDescriptor importDescriptor = module.symbolTable().importedTable(tableIndex); - final Sym[] dependencies = importDescriptor != null ? new Sym[]{new ImportTableSym(module.name(), importDescriptor)} : ResolutionDag.NO_DEPENDENCIES; + final Sym[] dependencies = new Sym[]{new InitializeTableSym(module.name(), tableIndex)}; resolutionDag.resolveLater(new ExportTableSym(module.name(), exportedTableName), dependencies, NO_RESOLVE_ACTION); } @@ -963,9 +983,7 @@ private static Object[] extractElemItems(WasmInstance instance, int bytecodeOffs void resolveElemSegment(WasmStore store, WasmInstance instance, int tableIndex, int elemSegmentId, int offsetAddress, byte[] offsetBytecode, int bytecodeOffset, int elementCount) { final Runnable resolveAction = () -> immediatelyResolveElemSegment(store, instance, tableIndex, offsetAddress, offsetBytecode, bytecodeOffset, elementCount); final ArrayList dependencies = new ArrayList<>(); - if (instance.symbolTable().importedTable(tableIndex) != null) { - dependencies.add(new ImportTableSym(instance.name(), instance.symbolTable().importedTable(tableIndex))); - } + dependencies.add(new InitializeTableSym(instance.name(), tableIndex)); if (elemSegmentId > 0) { dependencies.add(new ElemSym(instance.name(), elemSegmentId - 1)); } @@ -1007,7 +1025,6 @@ void resolvePassiveElemSegment(WasmStore store, WasmInstance instance, int elemS } addElemItemDependencies(instance, bytecodeOffset, elementCount, dependencies); resolutionDag.resolveLater(new ElemSym(instance.name(), elemSegmentId), dependencies.toArray(new Sym[0]), resolveAction); - } public void immediatelyResolvePassiveElementSegment(WasmStore store, WasmInstance instance, int elemSegmentId, int bytecodeOffset, int elementCount) { @@ -1340,7 +1357,7 @@ static class ImportTableSym extends Sym { @Override public String toString() { - return String.format("(import memory %s from %s into %s)", importDescriptor.memberName(), importDescriptor.moduleName(), moduleName); + return String.format("(import table %s from %s into %s)", importDescriptor.memberName(), importDescriptor.moduleName(), moduleName); } @Override @@ -1384,6 +1401,30 @@ public boolean equals(Object object) { } } + static class InitializeTableSym extends Sym { + final int tableIndex; + + InitializeTableSym(String moduleName, int tableIndex) { + super(moduleName); + this.tableIndex = tableIndex; + } + + @Override + public String toString() { + return String.format(Locale.ROOT, "(init table %d in %s)", tableIndex, moduleName); + } + + @Override + public int hashCode() { + return Integer.hashCode(tableIndex) ^ moduleName.hashCode(); + } + + @Override + public boolean equals(Object object) { + return object instanceof InitializeTableSym that && this.tableIndex == that.tableIndex && this.moduleName.equals(that.moduleName); + } + } + static class ElemSym extends Sym { final int elemSegmentId; @@ -1394,7 +1435,7 @@ static class ElemSym extends Sym { @Override public String toString() { - return String.format(Locale.ROOT, "(data %d in %s)", elemSegmentId, moduleName); + return String.format(Locale.ROOT, "(elem %d in %s)", elemSegmentId, moduleName); } @Override diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/ModuleLimits.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/ModuleLimits.java index 4a2e739e6084..8b32814a25ec 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/ModuleLimits.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/ModuleLimits.java @@ -80,7 +80,7 @@ public ModuleLimits(int moduleSizeLimit, int typeCountLimit, int functionCountLi int dataSegmentCountLimit, int elementSegmentCountLimit, int functionSizeLimit, int paramCountLimit, int resultCountLimit, int localCountLimit, int tableInstanceSizeLimit, int memoryInstanceSizeLimit, long memory64InstanceSizeLimit, int tagCountLimit) { this.moduleSizeLimit = minUnsigned(moduleSizeLimit, Integer.MAX_VALUE); - this.typeCountLimit = minUnsigned(typeCountLimit, Integer.MAX_VALUE); + this.typeCountLimit = minUnsigned(typeCountLimit, WasmType.MAX_TYPE_INDEX); this.functionCountLimit = minUnsigned(functionCountLimit, Integer.MAX_VALUE); this.tableCountLimit = minUnsigned(tableCountLimit, Integer.MAX_VALUE); this.multiMemoryCountLimit = minUnsigned(memoryCountLimit, Integer.MAX_VALUE); @@ -109,7 +109,7 @@ private static long minUnsigned(long a, long b) { static final ModuleLimits DEFAULTS = new ModuleLimits( Integer.MAX_VALUE, - Integer.MAX_VALUE, + WasmType.MAX_TYPE_INDEX, Integer.MAX_VALUE, Integer.MAX_VALUE, Integer.MAX_VALUE, diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java index 1119e5483b42..38434d1d2204 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java @@ -40,7 +40,6 @@ */ package org.graalvm.wasm; -import static org.graalvm.wasm.Assert.assertByteEqual; import static org.graalvm.wasm.Assert.assertIntEqual; import static org.graalvm.wasm.Assert.assertTrue; import static org.graalvm.wasm.Assert.assertUnsignedIntLess; @@ -51,13 +50,16 @@ import java.util.Arrays; import java.util.List; +import com.oracle.truffle.api.nodes.ExplodeLoop; import org.graalvm.collections.EconomicMap; import org.graalvm.collections.EconomicSet; import org.graalvm.collections.MapCursor; +import org.graalvm.wasm.api.Vector128; import org.graalvm.wasm.constants.GlobalModifier; import org.graalvm.wasm.constants.ImportIdentifier; import org.graalvm.wasm.exception.Failure; import org.graalvm.wasm.exception.WasmException; +import org.graalvm.wasm.exception.WasmRuntimeException; import org.graalvm.wasm.memory.WasmMemory; import org.graalvm.wasm.memory.WasmMemoryFactory; @@ -84,71 +86,414 @@ public abstract class SymbolTable { private static final byte GLOBAL_FUNCTION_INITIALIZER_BIT = 0x20; public static final int UNINITIALIZED_ADDRESS = Integer.MIN_VALUE; - private static final int NO_EQUIVALENCE_CLASS = 0; - static final int FIRST_EQUIVALENCE_CLASS = NO_EQUIVALENCE_CLASS + 1; + public static final int NO_EQUIVALENCE_CLASS = 0; + public static final int FIRST_EQUIVALENCE_CLASS = NO_EQUIVALENCE_CLASS + 1; - public static final class FunctionType { - @CompilationFinal(dimensions = 1) private final byte[] paramTypes; - @CompilationFinal(dimensions = 1) private final byte[] resultTypes; - private final int hashCode; + /** + * Represents a WebAssembly value type in its closed form, with all type indices replaced with + * their definitions. You can query the subtyping relation on types using the predicates + * {@link #isSupertypeOf(ClosedValueType)} and {@link #isSubtypeOf(ClosedValueType)}, both of + * which are written to be PE-friendly provided the receiver type is a PE constant. + *

+ * If you need to check whether two types are equivalent, instead of checking + * {@code A.isSupertypeOf(B) && A.isSubtypeOf(B)}, you can use {@code A.equals(B)}, since, in + * the WebAssembly type system, type equivalence corresponds to structural equality. + *

+ */ + public abstract static sealed class ClosedValueType { + // This is a workaround until we can use pattern matching in JDK 21+. + public enum Kind { + Number, + Vector, + Reference + } - FunctionType(byte[] paramTypes, byte[] resultTypes) { - this.paramTypes = paramTypes; - this.resultTypes = resultTypes; - this.hashCode = Arrays.hashCode(paramTypes) ^ Arrays.hashCode(resultTypes); + public abstract boolean isSupertypeOf(ClosedValueType valueSubType); + + public abstract boolean isSubtypeOf(ClosedValueType valueSuperType); + + public abstract boolean matchesValue(Object value); + + public abstract Kind kind(); + } + + public abstract static sealed class ClosedHeapType { + // This is a workaround until we can use pattern matching in JDK 21+. + public enum Kind { + Abstract, + Function + } + + public abstract boolean isSupertypeOf(ClosedHeapType heapSubType); + + public abstract boolean isSubtypeOf(ClosedHeapType heapSuperType); + + public abstract boolean matchesValue(Object value); + + public abstract Kind kind(); + } + + public static final class NumberType extends ClosedValueType { + public static final NumberType I32 = new NumberType(WasmType.I32_TYPE); + public static final NumberType I64 = new NumberType(WasmType.I64_TYPE); + public static final NumberType F32 = new NumberType(WasmType.F32_TYPE); + public static final NumberType F64 = new NumberType(WasmType.F64_TYPE); + + private final int value; + + private NumberType(int value) { + this.value = value; + } + + public int value() { + return value; + } + + @Override + public boolean isSupertypeOf(ClosedValueType valueSubType) { + return valueSubType == this; + } + + @Override + public boolean isSubtypeOf(ClosedValueType valueSuperType) { + return valueSuperType == this; + } + + @Override + public boolean matchesValue(Object val) { + return switch (value()) { + case WasmType.I32_TYPE -> val instanceof Integer; + case WasmType.I64_TYPE -> val instanceof Long; + case WasmType.F32_TYPE -> val instanceof Float; + case WasmType.F64_TYPE -> val instanceof Double; + default -> throw CompilerDirectives.shouldNotReachHere(); + }; + } + + @Override + public Kind kind() { + return Kind.Number; } - public static FunctionType create(byte[] paramTypes, byte[] resultTypes) { - return new FunctionType(paramTypes, resultTypes); + @Override + public boolean equals(Object that) { + return this == that; } - public byte[] paramTypes() { + @Override + public int hashCode() { + return value; + } + + @Override + public String toString() { + return switch (value()) { + case WasmType.I32_TYPE -> "i32"; + case WasmType.I64_TYPE -> "i64"; + case WasmType.F32_TYPE -> "f32"; + case WasmType.F64_TYPE -> "f64"; + default -> throw CompilerDirectives.shouldNotReachHere(); + }; + } + } + + public static final class VectorType extends ClosedValueType { + public static final VectorType V128 = new VectorType(WasmType.V128_TYPE); + + private final int value; + + private VectorType(int value) { + this.value = value; + } + + public int value() { + return value; + } + + @Override + public boolean isSupertypeOf(ClosedValueType valueSubType) { + return valueSubType == V128; + } + + @Override + public boolean isSubtypeOf(ClosedValueType valueSuperType) { + return valueSuperType == V128; + } + + @Override + public boolean matchesValue(Object val) { + return val instanceof Vector128; + } + + @Override + public Kind kind() { + return Kind.Vector; + } + + @Override + public boolean equals(Object that) { + return this == that; + } + + @Override + public int hashCode() { + return value; + } + + @Override + public String toString() { + return "v128"; + } + } + + public static final class ClosedReferenceType extends ClosedValueType { + public static final ClosedReferenceType FUNCREF = new ClosedReferenceType(true, AbstractHeapType.FUNC); + public static final ClosedReferenceType NONNULL_FUNCREF = new ClosedReferenceType(false, AbstractHeapType.FUNC); + public static final ClosedReferenceType EXTERNREF = new ClosedReferenceType(true, AbstractHeapType.EXTERN); + public static final ClosedReferenceType NONNULL_EXTERNREF = new ClosedReferenceType(false, AbstractHeapType.EXTERN); + public static final ClosedReferenceType EXNREF = new ClosedReferenceType(true, AbstractHeapType.EXN); + public static final ClosedReferenceType NONNULL_EXNREF = new ClosedReferenceType(false, AbstractHeapType.EXN); + + private final boolean nullable; + private final ClosedHeapType closedHeapType; + + public ClosedReferenceType(boolean nullable, ClosedHeapType closedHeapType) { + this.nullable = nullable; + this.closedHeapType = closedHeapType; + } + + public boolean nullable() { + return nullable; + } + + public ClosedHeapType heapType() { + return closedHeapType; + } + + @Override + public boolean isSupertypeOf(ClosedValueType valueSubType) { + return valueSubType instanceof ClosedReferenceType referenceSubType && (!referenceSubType.nullable || this.nullable) && + this.closedHeapType.isSupertypeOf(referenceSubType.closedHeapType); + } + + @Override + public boolean isSubtypeOf(ClosedValueType valueSuperType) { + return valueSuperType instanceof ClosedReferenceType referencedSuperType && (!this.nullable || referencedSuperType.nullable) && + this.closedHeapType.isSubtypeOf(referencedSuperType.closedHeapType); + } + + @Override + public boolean matchesValue(Object value) { + return nullable() && value == WasmConstant.NULL || heapType().matchesValue(value); + } + + @Override + public Kind kind() { + return Kind.Reference; + } + + @Override + public boolean equals(Object obj) { + return obj instanceof ClosedReferenceType that && this.nullable == that.nullable && this.closedHeapType.equals(that.closedHeapType); + } + + @Override + public int hashCode() { + return Boolean.hashCode(nullable) ^ closedHeapType.hashCode(); + } + + @Override + public String toString() { + CompilerAsserts.neverPartOfCompilation(); + if (this == FUNCREF) { + return "funcref"; + } else if (this == EXTERNREF) { + return "externref"; + } else if (this == EXNREF) { + return "exnref"; + } else { + StringBuilder buf = new StringBuilder(); + buf.append("(ref "); + if (nullable) { + buf.append("null "); + } + buf.append(closedHeapType.toString()); + buf.append(")"); + return buf.toString(); + } + } + } + + public static final class AbstractHeapType extends ClosedHeapType { + public static final AbstractHeapType FUNC = new AbstractHeapType(WasmType.FUNC_HEAPTYPE); + public static final AbstractHeapType EXTERN = new AbstractHeapType(WasmType.EXTERN_HEAPTYPE); + public static final AbstractHeapType EXN = new AbstractHeapType(WasmType.EXN_HEAPTYPE); + + private final int value; + + private AbstractHeapType(int value) { + this.value = value; + } + + public int value() { + return value; + } + + @Override + public boolean isSupertypeOf(ClosedHeapType heapSubType) { + return switch (this.value) { + case WasmType.FUNC_HEAPTYPE -> heapSubType == FUNC || heapSubType instanceof ClosedFunctionType; + case WasmType.EXTERN_HEAPTYPE -> heapSubType == EXTERN; + case WasmType.EXN_HEAPTYPE -> heapSubType == EXN; + default -> throw CompilerDirectives.shouldNotReachHere(); + }; + } + + @Override + public boolean isSubtypeOf(ClosedHeapType heapSuperType) { + return heapSuperType == this; + } + + @Override + public boolean matchesValue(Object val) { + return switch (this.value) { + case WasmType.FUNC_HEAPTYPE -> val instanceof WasmFunctionInstance; + case WasmType.EXTERN_HEAPTYPE -> true; + case WasmType.EXN_HEAPTYPE -> val instanceof WasmRuntimeException; + default -> throw CompilerDirectives.shouldNotReachHere(); + }; + } + + @Override + public Kind kind() { + return Kind.Abstract; + } + + @Override + public boolean equals(Object that) { + return this == that; + } + + @Override + public int hashCode() { + return value; + } + + @Override + public String toString() { + return switch (this.value) { + case WasmType.FUNC_HEAPTYPE -> "func"; + case WasmType.EXTERN_HEAPTYPE -> "extern"; + case WasmType.EXN_HEAPTYPE -> "exn"; + default -> throw CompilerDirectives.shouldNotReachHere(); + }; + } + } + + public static final class ClosedFunctionType extends ClosedHeapType { + @CompilationFinal(dimensions = 1) private final ClosedValueType[] paramTypes; + @CompilationFinal(dimensions = 1) private final ClosedValueType[] resultTypes; + + public ClosedFunctionType(ClosedValueType[] paramTypes, ClosedValueType[] resultTypes) { + this.paramTypes = paramTypes; + this.resultTypes = resultTypes; + } + + public ClosedValueType[] paramTypes() { return paramTypes; } - public byte[] resultTypes() { + public ClosedValueType[] resultTypes() { return resultTypes; } @Override - public int hashCode() { - return hashCode; + @ExplodeLoop(kind = ExplodeLoop.LoopExplosionKind.FULL_UNROLL) + public boolean isSupertypeOf(ClosedHeapType heapSubType) { + if (!(heapSubType instanceof ClosedFunctionType functionSubType)) { + return false; + } + if (this.paramTypes.length != functionSubType.paramTypes.length) { + return false; + } + for (int i = 0; i < this.paramTypes.length; i++) { + CompilerAsserts.partialEvaluationConstant(this.paramTypes[i]); + if (!this.paramTypes[i].isSubtypeOf(functionSubType.paramTypes[i])) { + return false; + } + } + if (this.resultTypes.length != functionSubType.resultTypes.length) { + return false; + } + for (int i = 0; i < this.resultTypes.length; i++) { + CompilerAsserts.partialEvaluationConstant(this.resultTypes[i]); + if (!this.resultTypes[i].isSupertypeOf(functionSubType.resultTypes[i])) { + return false; + } + } + return true; } @Override - public boolean equals(Object object) { - if (!(object instanceof FunctionType that)) { + @ExplodeLoop(kind = ExplodeLoop.LoopExplosionKind.FULL_UNROLL) + public boolean isSubtypeOf(ClosedHeapType heapSuperType) { + if (heapSuperType == AbstractHeapType.FUNC) { + return true; + } + if (!(heapSuperType instanceof ClosedFunctionType functionSuperType)) { return false; } - if (this.paramTypes.length != that.paramTypes.length) { + if (this.paramTypes.length != functionSuperType.paramTypes.length) { return false; } for (int i = 0; i < this.paramTypes.length; i++) { - if (this.paramTypes[i] != that.paramTypes[i]) { + CompilerAsserts.partialEvaluationConstant(this.paramTypes[i]); + if (!this.paramTypes[i].isSupertypeOf(functionSuperType.paramTypes[i])) { return false; } } - if (this.resultTypes.length != that.resultTypes.length) { + if (this.resultTypes.length != functionSuperType.resultTypes.length) { return false; } for (int i = 0; i < this.resultTypes.length; i++) { - if (this.resultTypes[i] != that.resultTypes[i]) { + CompilerAsserts.partialEvaluationConstant(this.resultTypes[i]); + if (!this.resultTypes[i].isSubtypeOf(functionSuperType.resultTypes[i])) { return false; } } return true; } + @Override + public boolean matchesValue(Object value) { + return value instanceof WasmFunctionInstance instance && isSupertypeOf(instance.function().closedType()); + } + + @Override + public Kind kind() { + return Kind.Function; + } + + @Override + public boolean equals(Object obj) { + return obj instanceof ClosedFunctionType that && Arrays.equals(this.paramTypes, that.paramTypes) && Arrays.equals(this.resultTypes, that.resultTypes); + } + + @Override + public int hashCode() { + return Arrays.hashCode(paramTypes) ^ Arrays.hashCode(resultTypes); + } + @Override public String toString() { CompilerAsserts.neverPartOfCompilation(); String[] paramNames = new String[paramTypes.length]; for (int i = 0; i < paramTypes.length; i++) { - paramNames[i] = WasmType.toString(paramTypes[i]); + paramNames[i] = paramTypes[i].toString(); } String[] resultNames = new String[resultTypes.length]; for (int i = 0; i < resultTypes.length; i++) { - resultNames[i] = WasmType.toString(resultTypes[i]); + resultNames[i] = resultTypes[i].toString(); } return "(" + String.join(" ", paramNames) + ")->(" + String.join(" ", resultNames) + ")"; } @@ -161,8 +506,12 @@ public String toString() { * Note: this is the upper bound defined by the module. A table instance * might have a lower internal max allowed size in practice. * @param elemType The element type of the table. + * @param initValue The initial value of the table's elements, can be {@code null} if no + * initializer present + * @param initBytecode The bytecode of the table's initializer expression, can be {@code null} + * if no initializer present */ - public record TableInfo(int initialSize, int maximumSize, byte elemType) { + public record TableInfo(int initialSize, int maximumSize, int elemType, Object initValue, byte[] initBytecode) { } /** @@ -214,6 +563,13 @@ public record TagInfo(byte attribute, int typeIndex) { */ @CompilationFinal(dimensions = 1) private int[] typeOffsets; + /** + * Stores the closed forms of all the types defined in this module. Closed forms replace type + * indices with the definitions of the referenced types, resulting in a tree-like data + * structure. + */ + @CompilationFinal(dimensions = 1) private ClosedHeapType[] closedTypes; + /** * Stores the type equivalence class. *

@@ -269,14 +625,18 @@ public record TagInfo(byte attribute, int typeIndex) { @CompilationFinal private int startFunctionIndex; /** - * A global type is the value type of the global, followed by its mutability. This is encoded as - * two bytes -- the lowest (0th) byte is the value type. The 1st byte is organized like this: + * Value types of globals. + */ + @CompilationFinal(dimensions = 1) private int[] globalTypes; + + /** + * Mutability flags of globals. These are encoded like this: *

* * | . | . | . | functionOrIndex flag | reference flag | initialized flag | exported flag | mutable flag | * */ - @CompilationFinal(dimensions = 1) private byte[] globalTypes; + @CompilationFinal(dimensions = 1) private byte[] globalFlags; /** * The values or indices used for initializing globals. @@ -417,7 +777,7 @@ public record TagInfo(byte attribute, int typeIndex) { @CompilationFinal private int codeEntryCount; /** - * All function indices that can be references via + * All function indices that can be referenced via * {@link org.graalvm.wasm.constants.Instructions#REF_FUNC}. */ @CompilationFinal private EconomicSet functionReferences; @@ -426,6 +786,7 @@ public record TagInfo(byte attribute, int typeIndex) { CompilerAsserts.neverPartOfCompilation(); this.typeData = new int[INITIAL_DATA_SIZE]; this.typeOffsets = new int[INITIAL_TYPE_SIZE]; + this.closedTypes = new ClosedHeapType[INITIAL_TYPE_SIZE]; this.typeEquivalenceClasses = new int[INITIAL_TYPE_SIZE]; this.typeDataSize = 0; this.typeCount = 0; @@ -438,7 +799,8 @@ public record TagInfo(byte attribute, int typeIndex) { this.exportedFunctions = EconomicMap.create(); this.exportedFunctionsByIndex = EconomicMap.create(); this.startFunctionIndex = -1; - this.globalTypes = new byte[2 * INITIAL_GLOBALS_SIZE]; + this.globalTypes = new int[INITIAL_GLOBALS_SIZE]; + this.globalFlags = new byte[INITIAL_GLOBALS_SIZE]; this.globalInitializers = new Object[INITIAL_GLOBALS_SIZE]; this.globalInitializersBytecode = new byte[INITIAL_GLOBALS_BYTECODE_SIZE][]; this.importedGlobals = EconomicMap.create(); @@ -473,7 +835,7 @@ private void checkNotParsed() { private void checkUniqueExport(String name) { CompilerAsserts.neverPartOfCompilation(); - if (exportedFunctions.containsKey(name) || exportedGlobals.containsKey(name) || exportedMemories.containsKey(name) || exportedTables.containsKey(name)) { + if (exportedFunctions.containsKey(name) || exportedGlobals.containsKey(name) || exportedMemories.containsKey(name) || exportedTables.containsKey(name) || exportedTags.containsKey(name)) { throw WasmException.create(Failure.DUPLICATE_EXPORT, "All export names must be different, but '" + name + "' is exported twice."); } } @@ -482,18 +844,6 @@ public void checkFunctionIndex(int funcIndex) { assertUnsignedIntLess(funcIndex, numFunctions, Failure.UNKNOWN_FUNCTION); } - private static int[] reallocate(int[] array, int currentSize, int newLength) { - int[] newArray = new int[newLength]; - System.arraycopy(array, 0, newArray, 0, currentSize); - return newArray; - } - - private static WasmFunction[] reallocate(WasmFunction[] array, int currentSize, int newLength) { - WasmFunction[] newArray = new WasmFunction[newLength]; - System.arraycopy(array, 0, newArray, 0, currentSize); - return newArray; - } - /** * Ensure that the {@link #typeData} array has enough space to store {@code index}. If there is * no enough space, then a reallocation of the array takes place, doubling its capacity. @@ -504,7 +854,7 @@ private static WasmFunction[] reallocate(WasmFunction[] array, int currentSize, private void ensureTypeDataCapacity(int index) { if (typeData.length <= index) { int newLength = Math.max(Integer.highestOneBit(index) << 1, 2 * typeData.length); - typeData = reallocate(typeData, typeDataSize, newLength); + typeData = Arrays.copyOf(typeData, newLength); } } @@ -519,8 +869,9 @@ private void ensureTypeDataCapacity(int index) { private void ensureTypeCapacity(int index) { if (typeOffsets.length <= index) { int newLength = Math.max(Integer.highestOneBit(index) << 1, 2 * typeOffsets.length); - typeOffsets = reallocate(typeOffsets, typeCount, newLength); - typeEquivalenceClasses = reallocate(typeEquivalenceClasses, typeCount, newLength); + typeOffsets = Arrays.copyOf(typeOffsets, newLength); + closedTypes = Arrays.copyOf(closedTypes, newLength); + typeEquivalenceClasses = Arrays.copyOf(typeEquivalenceClasses, newLength); } } @@ -542,7 +893,7 @@ int allocateFunctionType(int paramCount, int resultCount, boolean isMultiValue) return typeIdx; } - public int allocateFunctionType(byte[] paramTypes, byte[] resultTypes, boolean isMultiValue) { + public int allocateFunctionType(int[] paramTypes, int[] resultTypes, boolean isMultiValue) { checkNotParsed(); final int typeIdx = allocateFunctionType(paramTypes.length, resultTypes.length, isMultiValue); for (int i = 0; i < paramTypes.length; i++) { @@ -551,21 +902,34 @@ public int allocateFunctionType(byte[] paramTypes, byte[] resultTypes, boolean i for (int i = 0; i < resultTypes.length; i++) { registerFunctionTypeResultType(typeIdx, i, resultTypes[i]); } + finishFunctionType(typeIdx); return typeIdx; } - void registerFunctionTypeParameterType(int funcTypeIdx, int paramIdx, byte type) { + void registerFunctionTypeParameterType(int funcTypeIdx, int paramIdx, int type) { checkNotParsed(); int idx = 2 + typeOffsets[funcTypeIdx] + paramIdx; typeData[idx] = type; } - void registerFunctionTypeResultType(int funcTypeIdx, int resultIdx, byte type) { + void registerFunctionTypeResultType(int funcTypeIdx, int resultIdx, int type) { checkNotParsed(); int idx = 2 + typeOffsets[funcTypeIdx] + typeData[typeOffsets[funcTypeIdx]] + resultIdx; typeData[idx] = type; } + void finishFunctionType(int funcTypeIdx) { + ClosedValueType[] paramTypes = new ClosedValueType[functionTypeParamCount(funcTypeIdx)]; + for (int i = 0; i < paramTypes.length; i++) { + paramTypes[i] = closedTypeOf(functionTypeParamTypeAt(funcTypeIdx, i)); + } + ClosedValueType[] resultTypes = new ClosedValueType[functionTypeResultCount(funcTypeIdx)]; + for (int i = 0; i < resultTypes.length; i++) { + resultTypes[i] = closedTypeOf(functionTypeResultTypeAt(funcTypeIdx, i)); + } + closedTypes[funcTypeIdx] = new ClosedFunctionType(paramTypes, resultTypes); + } + public int equivalenceClass(int typeIndex) { return typeEquivalenceClasses[typeIndex]; } @@ -581,7 +945,7 @@ void setEquivalenceClass(int index, int eqClass) { private void ensureFunctionsCapacity(int index) { if (functions.length <= index) { int newLength = Math.max(Integer.highestOneBit(index) << 1, 2 * functions.length); - functions = reallocate(functions, numFunctions, newLength); + functions = Arrays.copyOf(functions, newLength); } } @@ -655,29 +1019,29 @@ public WasmFunction startFunction() { protected abstract WasmModule module(); - public byte functionTypeParamTypeAt(int typeIndex, int i) { + public int functionTypeParamTypeAt(int typeIndex, int paramIndex) { int typeOffset = typeOffsets[typeIndex]; - return (byte) typeData[typeOffset + 2 + i]; + return typeData[typeOffset + 2 + paramIndex]; } - public byte functionTypeResultTypeAt(int typeIndex, int resultIndex) { + public int functionTypeResultTypeAt(int typeIndex, int resultIndex) { int typeOffset = typeOffsets[typeIndex]; int paramCount = typeData[typeOffset]; - return (byte) typeData[typeOffset + 2 + paramCount + resultIndex]; + return typeData[typeOffset + 2 + paramCount + resultIndex]; } - private byte[] functionTypeParamTypesAsArray(int typeIndex) { + public int[] functionTypeParamTypesAsArray(int typeIndex) { int paramCount = functionTypeParamCount(typeIndex); - byte[] paramTypes = new byte[paramCount]; + int[] paramTypes = new int[paramCount]; for (int i = 0; i < paramCount; ++i) { paramTypes[i] = functionTypeParamTypeAt(typeIndex, i); } return paramTypes; } - private byte[] functionTypeResultTypesAsArray(int typeIndex) { + public int[] functionTypeResultTypesAsArray(int typeIndex) { int resultTypeCount = functionTypeResultCount(typeIndex); - byte[] resultTypes = new byte[resultTypeCount]; + int[] resultTypes = new int[resultTypeCount]; for (int i = 0; i < resultTypeCount; i++) { resultTypes[i] = functionTypeResultTypeAt(typeIndex, i); } @@ -688,8 +1052,96 @@ int typeCount() { return typeCount; } - public FunctionType typeAt(int index) { - return new FunctionType(functionTypeParamTypesAsArray(index), functionTypeResultTypesAsArray(index)); + /** + * Convenience function for calling {@link #closedTypeAt(int)} when the defined type at index + * {@code typeIndex} is known to be a function type. + * + * @see #closedTypeAt(int) + */ + public ClosedFunctionType closedFunctionTypeAt(int typeIndex) { + return (ClosedFunctionType) closedTypeAt(typeIndex); + } + + /** + * Fetches the closed form of a type defined in this module at index {@code typeIndex}. + * + * @param typeIndex index of a type defined in this module + */ + public ClosedHeapType closedTypeAt(int typeIndex) { + return closedTypes[typeIndex]; + } + + /** + * A convenient way of calling {@link #closedTypeOf(int, SymbolTable)} when a + * {@link SymbolTable} is present. + * + * @see #closedTypeOf(int, SymbolTable) + */ + public ClosedValueType closedTypeOf(int type) { + return SymbolTable.closedTypeOf(type, this); + } + + /** + * Maps a type encoded as an {@code int} (as per {@link WasmType}) into its closed form, + * represented as a {@link ClosedValueType}. Any type indices in the type are resolved using the + * provided symbol table. + *

+ * It is legal to call this function with a null {@code symbolTable}. This is used in cases + * where we need to map a predefined value type to the closed type data representation (i.e. we + * know the type is already closed anyway and so it does not contain any type indices). + *

+ * + * @param type the {@code int}-encoded Wasm type to be expanded + * @param symbolTable used for lookup of type definitions when expanding type indices + */ + public static ClosedValueType closedTypeOf(int type, SymbolTable symbolTable) { + return switch (type) { + case WasmType.I32_TYPE -> NumberType.I32; + case WasmType.I64_TYPE -> NumberType.I64; + case WasmType.F32_TYPE -> NumberType.F32; + case WasmType.F64_TYPE -> NumberType.F64; + case WasmType.V128_TYPE -> VectorType.V128; + default -> { + assert WasmType.isReferenceType(type); + boolean nullable = WasmType.isNullable(type); + yield switch (WasmType.getAbstractHeapType(type)) { + case WasmType.FUNC_HEAPTYPE -> nullable ? ClosedReferenceType.FUNCREF : ClosedReferenceType.NONNULL_FUNCREF; + case WasmType.EXTERN_HEAPTYPE -> nullable ? ClosedReferenceType.EXTERNREF : ClosedReferenceType.NONNULL_EXTERNREF; + case WasmType.EXN_HEAPTYPE -> nullable ? ClosedReferenceType.EXNREF : ClosedReferenceType.NONNULL_EXNREF; + default -> { + assert WasmType.isConcreteReferenceType(type); + assert symbolTable != null; + int typeIndex = WasmType.getTypeIndex(type); + ClosedHeapType heapType = symbolTable.closedTypeAt(typeIndex); + yield new ClosedReferenceType(nullable, heapType); + } + }; + } + }; + } + + /** + * Checks whether the type {@code actualType} matches the type {@code expectedType}. This is the + * case when {@code actualType} is a subtype of {@code expectedType}. + */ + public boolean matchesType(int expectedType, int actualType) { + switch (expectedType) { + case WasmType.BOT -> { + return false; + } + case WasmType.TOP -> { + return true; + } + } + switch (actualType) { + case WasmType.BOT -> { + return true; + } + case WasmType.TOP -> { + return false; + } + } + return closedTypeOf(expectedType).isSupertypeOf(closedTypeOf(actualType)); } public void importSymbol(ImportDescriptor descriptor) { @@ -758,11 +1210,14 @@ public WasmFunction importedFunction(ImportDescriptor descriptor) { private void ensureGlobalsCapacity(int index) { while (index >= globalInitializers.length) { - final byte[] nGlobalTypes = new byte[globalTypes.length * 2]; + final int[] nGlobalTypes = new int[globalTypes.length * 2]; + final byte[] nGlobalFlags = new byte[globalFlags.length * 2]; final Object[] nGlobalInitializers = new Object[globalInitializers.length * 2]; System.arraycopy(globalTypes, 0, nGlobalTypes, 0, globalTypes.length); + System.arraycopy(globalFlags, 0, nGlobalFlags, 0, globalFlags.length); System.arraycopy(globalInitializers, 0, nGlobalInitializers, 0, globalInitializers.length); globalTypes = nGlobalTypes; + globalFlags = nGlobalFlags; globalInitializers = nGlobalInitializers; } } @@ -779,8 +1234,7 @@ private void ensureGlobalInitializersBytecodeCapacity(int index) { * Allocates a global index in the symbol table, for a global variable that was already * allocated. */ - void allocateGlobal(int index, byte valueType, byte mutability, boolean initialized, boolean imported, byte[] initBytecode, Object initialValue) { - assert (valueType & 0xff) == valueType; + void allocateGlobal(int index, int valueType, byte mutability, boolean initialized, boolean imported, byte[] initBytecode, Object initialValue) { checkNotParsed(); ensureGlobalsCapacity(index); numGlobals = maxUnsigned(index + 1, numGlobals); @@ -808,8 +1262,8 @@ void allocateGlobal(int index, byte valueType, byte mutability, boolean initiali globalInitializersBytecode[initBytecodeIndex] = initBytecode; globalInitializers[index] = initBytecodeIndex; } - globalTypes[2 * index] = valueType; - globalTypes[2 * index + 1] = flags; + globalTypes[index] = valueType; + globalFlags[index] = flags; } /** @@ -817,7 +1271,7 @@ void allocateGlobal(int index, byte valueType, byte mutability, boolean initiali * default, but may be exported using {@link #exportGlobal}. Imported globals are declared using * {@link #importGlobal} instead. This method may only be called during parsing, before linking. */ - void declareGlobal(int index, byte valueType, byte mutability, boolean initialized, byte[] initBytecode, Object initialValue) { + void declareGlobal(int index, int valueType, byte mutability, boolean initialized, byte[] initBytecode, Object initialValue) { assert initialized == (initBytecode == null) : index; allocateGlobal(index, valueType, mutability, initialized, false, initBytecode, initialValue); module().addLinkAction((context, store, instance, imports) -> { @@ -828,7 +1282,7 @@ void declareGlobal(int index, byte valueType, byte mutability, boolean initializ /** * Declares an imported global. May be re-exported. */ - void importGlobal(String moduleName, String globalName, int index, byte valueType, byte mutability) { + void importGlobal(String moduleName, String globalName, int index, int valueType, byte mutability) { final ImportDescriptor descriptor = new ImportDescriptor(moduleName, globalName, ImportIdentifier.GLOBAL, index, numImportedSymbols()); importedGlobals.put(index, descriptor); importSymbol(descriptor); @@ -879,12 +1333,12 @@ public boolean isGlobalMutable(int index) { return globalMutability(index) == GlobalModifier.MUTABLE; } - public byte globalValueType(int index) { - return globalTypes[2 * index]; + public int globalValueType(int index) { + return globalTypes[index]; } private byte globalFlags(int index) { - return globalTypes[2 * index + 1]; + return globalFlags[index]; } public boolean globalInitialized(int index) { @@ -929,14 +1383,14 @@ void exportGlobal(String name, int index) { if (!globalExternal(index)) { numExternalGlobals++; } - globalTypes[2 * index + 1] |= GLOBAL_EXPORTED_BIT; + globalFlags[index] |= GLOBAL_EXPORTED_BIT; exportedGlobals.put(name, index); module().addLinkAction((context, store, instance, imports) -> { store.linker().resolveGlobalExport(instance.module(), name, index); }); } - public void declareExportedGlobalWithValue(String name, int index, byte valueType, byte mutability, Object value) { + public void declareExportedGlobalWithValue(String name, int index, int valueType, byte mutability, Object value) { checkNotParsed(); declareGlobal(index, valueType, mutability, true, null, value); exportGlobal(name, index); @@ -950,27 +1404,29 @@ private void ensureTableCapacity(int index) { } } - public void allocateTable(int index, int declaredMinSize, int declaredMaxSize, byte elemType, boolean referenceTypes) { + public void declareTable(int index, int declaredMinSize, int declaredMaxSize, int elemType, byte[] initBytecode, Object initValue, boolean referenceTypes) { checkNotParsed(); - addTable(index, declaredMinSize, declaredMaxSize, elemType, referenceTypes); + addTable(index, declaredMinSize, declaredMaxSize, elemType, initValue, initBytecode, referenceTypes); module().addLinkAction((context, store, instance, imports) -> { final int maxAllowedSize = minUnsigned(declaredMaxSize, module().limits().tableInstanceSizeLimit()); module().limits().checkTableInstanceSize(declaredMinSize); final WasmTable wasmTable; if (context.getContextOptions().memoryOverheadMode()) { // Initialize an empty table in memory overhead mode. - wasmTable = new WasmTable(0, 0, 0, elemType); + wasmTable = new WasmTable(0, 0, 0, elemType, this); } else { - wasmTable = new WasmTable(declaredMinSize, declaredMaxSize, maxAllowedSize, elemType); + wasmTable = new WasmTable(declaredMinSize, declaredMaxSize, maxAllowedSize, elemType, this); } final int address = store.tables().register(wasmTable); instance.setTableAddress(index, address); + + store.linker().resolveTableInitialization(instance, index, initBytecode, initValue); }); } - void importTable(String moduleName, String tableName, int index, int initSize, int maxSize, byte elemType, boolean referenceTypes) { + void importTable(String moduleName, String tableName, int index, int initSize, int maxSize, int elemType, boolean referenceTypes) { checkNotParsed(); - addTable(index, initSize, maxSize, elemType, referenceTypes); + addTable(index, initSize, maxSize, elemType, null, null, referenceTypes); final ImportDescriptor importedTable = new ImportDescriptor(moduleName, tableName, ImportIdentifier.TABLE, index, numImportedSymbols()); importedTables.put(index, importedTable); importSymbol(importedTable); @@ -980,13 +1436,13 @@ void importTable(String moduleName, String tableName, int index, int initSize, i }); } - void addTable(int index, int minSize, int maxSize, byte elemType, boolean referenceTypes) { + void addTable(int index, int minSize, int maxSize, int elemType, Object initValue, byte[] initBytecode, boolean referenceTypes) { if (!referenceTypes) { assertTrue(importedTables.isEmpty(), "A table has already been imported in the module.", Failure.MULTIPLE_TABLES); assertTrue(tableCount == 0, "A table has already been declared in the module.", Failure.MULTIPLE_TABLES); } ensureTableCapacity(index); - final TableInfo table = new TableInfo(minSize, maxSize, elemType); + final TableInfo table = new TableInfo(minSize, maxSize, elemType, initValue, initBytecode); tables[index] = table; tableCount++; } @@ -1040,12 +1496,24 @@ public int tableMaximumSize(int index) { return table.maximumSize; } - public byte tableElementType(int index) { + public int tableElementType(int index) { final TableInfo table = tables[index]; assert table != null; return table.elemType; } + public Object tableInitialValue(int index) { + final TableInfo table = tables[index]; + assert table != null; + return table.initValue; + } + + public byte[] tableInitializerBytecode(int index) { + final TableInfo table = tables[index]; + assert table != null; + return table.initBytecode; + } + private void ensureMemoryCapacity(int index) { if (index >= memories.length) { final MemoryInfo[] nMemories = new MemoryInfo[Math.max(Integer.highestOneBit(index) << 1, 2 * memories.length)]; @@ -1164,7 +1632,7 @@ public void allocateTag(int index, byte attribute, int typeIndex) { checkNotParsed(); addTag(index, attribute, typeIndex); module().addLinkAction((context, store, instance, imports) -> { - final WasmTag tag = new WasmTag(typeAt(typeIndex)); + final WasmTag tag = new WasmTag(closedFunctionTypeAt(typeIndex)); instance.setTag(index, tag); }); } @@ -1173,7 +1641,7 @@ public void importTag(String moduleName, String tagName, int index, byte attribu checkNotParsed(); addTag(index, attribute, typeIndex); final ImportDescriptor importedTag = new ImportDescriptor(moduleName, tagName, ImportIdentifier.TAG, index, numImportedSymbols()); - final FunctionType type = typeAt(typeIndex); + final ClosedFunctionType type = closedFunctionTypeAt(typeIndex); importedTags.put(index, importedTag); importSymbol(importedTag); module().addLinkAction((context, store, instance, imports) -> { @@ -1314,8 +1782,8 @@ public void checkElemIndex(int elemIndex) { assertUnsignedIntLess(elemIndex, elemSegmentCount, Failure.UNKNOWN_ELEM_SEGMENT); } - public void checkElemType(int elemIndex, byte expectedType) { - assertByteEqual(expectedType, (byte) elemInstances[elemIndex], Failure.TYPE_MISMATCH); + public void checkElemType(int elemIndex, int expectedType) { + Assert.assertTrue(matchesType(expectedType, (int) elemInstances[elemIndex]), Failure.TYPE_MISMATCH); } private void ensureElemInstanceCapacity(int index) { @@ -1328,9 +1796,9 @@ private void ensureElemInstanceCapacity(int index) { } } - void setElemInstance(int index, int offset, byte elemType) { + void setElemInstance(int index, int offset, int elemType) { ensureElemInstanceCapacity(index); - elemInstances[index] = (long) offset << 32 | (elemType & 0xFF); + elemInstances[index] = (long) offset << 32 | (elemType & 0xFFFF_FFFFL); elemSegmentCount++; } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmCodeEntry.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmCodeEntry.java index 6dd87c4a6a35..4e8636195de9 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmCodeEntry.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmCodeEntry.java @@ -47,15 +47,16 @@ public final class WasmCodeEntry { private final WasmFunction function; @CompilationFinal(dimensions = 1) private final byte[] bytecode; - @CompilationFinal(dimensions = 1) private final byte[] localTypes; - @CompilationFinal(dimensions = 1) private final byte[] resultTypes; + @CompilationFinal(dimensions = 1) private final int[] localTypes; + @CompilationFinal(dimensions = 1) private final int[] resultTypes; private final BranchProfile errorBranch = BranchProfile.create(); private final BranchProfile exceptionBranch = BranchProfile.create(); + private final BranchProfile subtypingBranch = BranchProfile.create(); private final int numLocals; private final int resultCount; private final boolean usesMemoryZero; - public WasmCodeEntry(WasmFunction function, byte[] bytecode, byte[] localTypes, byte[] resultTypes, boolean usesMemoryZero) { + public WasmCodeEntry(WasmFunction function, byte[] bytecode, int[] localTypes, int[] resultTypes, boolean usesMemoryZero) { this.function = function; this.bytecode = bytecode; this.localTypes = localTypes; @@ -73,7 +74,7 @@ public byte[] bytecode() { return bytecode; } - public byte localType(int index) { + public int localType(int index) { return localTypes[index]; } @@ -89,7 +90,7 @@ public int resultCount() { return resultCount; } - public byte resultType(int index) { + public int resultType(int index) { return resultTypes[index]; } @@ -101,6 +102,10 @@ public void exceptionBranch() { exceptionBranch.enter(); } + public void subtypingBranch() { + subtypingBranch.enter(); + } + public boolean usesMemoryZero() { return usesMemoryZero; } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmContextOptions.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmContextOptions.java index 70667ba2e074..085bd6f34cb1 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmContextOptions.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmContextOptions.java @@ -61,6 +61,7 @@ public final class WasmContextOptions { @CompilationFinal private boolean simd; @CompilationFinal private boolean relaxedSimd; @CompilationFinal private boolean exceptions; + @CompilationFinal private boolean typedFunctionReferences; @CompilationFinal private boolean memoryOverheadMode; @CompilationFinal private boolean constantRandomGet; @@ -93,6 +94,7 @@ private void setOptionValues() { this.simd = readBooleanOption(WasmOptions.SIMD); this.relaxedSimd = readBooleanOption(WasmOptions.RelaxedSIMD); this.exceptions = readBooleanOption(WasmOptions.Exceptions); + this.typedFunctionReferences = readBooleanOption(WasmOptions.TypedFunctionReferences); this.memoryOverheadMode = readBooleanOption(WasmOptions.MemoryOverheadMode); this.constantRandomGet = readBooleanOption(WasmOptions.WasiConstantRandomGet); this.directByteBufferMemoryAccess = readBooleanOption(WasmOptions.DirectByteBufferMemoryAccess); @@ -165,6 +167,10 @@ public boolean supportExceptions() { return exceptions; } + public boolean supportTypedFunctionReferences() { + return typedFunctionReferences; + } + public boolean memoryOverheadMode() { return memoryOverheadMode; } @@ -199,6 +205,7 @@ public int hashCode() { hash = 53 * hash + (this.simd ? 1 : 0); hash = 53 * hash + (this.relaxedSimd ? 1 : 0); hash = 53 * hash + (this.exceptions ? 1 : 0); + hash = 53 * hash + (this.typedFunctionReferences ? 1 : 0); hash = 53 * hash + (this.memoryOverheadMode ? 1 : 0); hash = 53 * hash + (this.constantRandomGet ? 1 : 0); hash = 53 * hash + (this.directByteBufferMemoryAccess ? 1 : 0); @@ -251,6 +258,9 @@ public boolean equals(Object obj) { if (this.exceptions != other.exceptions) { return false; } + if (this.typedFunctionReferences != other.typedFunctionReferences) { + return false; + } if (this.memoryOverheadMode != other.memoryOverheadMode) { return false; } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmFunction.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmFunction.java index 4a96698b673e..a54608092eea 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmFunction.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmFunction.java @@ -43,12 +43,15 @@ import com.oracle.truffle.api.CallTarget; import com.oracle.truffle.api.CompilerDirectives.CompilationFinal; import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; +import org.graalvm.wasm.exception.Failure; +import org.graalvm.wasm.exception.WasmException; public final class WasmFunction { private final SymbolTable symbolTable; private final int index; private final ImportDescriptor importDescriptor; private final int typeIndex; + private final SymbolTable.ClosedFunctionType closedFunctionType; @CompilationFinal private int typeEquivalenceClass; @CompilationFinal private String debugName; @CompilationFinal private CallTarget callTarget; @@ -63,7 +66,7 @@ public WasmFunction(SymbolTable symbolTable, int index, int typeIndex, ImportDes this.index = index; this.importDescriptor = importDescriptor; this.typeIndex = typeIndex; - this.typeEquivalenceClass = -1; + this.closedFunctionType = symbolTable.closedFunctionTypeAt(typeIndex); } public String moduleName() { @@ -74,22 +77,33 @@ public int paramCount() { return symbolTable.functionTypeParamCount(typeIndex); } - public byte paramTypeAt(int argumentIndex) { + public int paramTypeAt(int argumentIndex) { return symbolTable.functionTypeParamTypeAt(typeIndex, argumentIndex); } + public int[] paramTypes() { + return symbolTable.functionTypeParamTypesAsArray(typeIndex); + } + public int resultCount() { return symbolTable.functionTypeResultCount(typeIndex); } - public byte resultTypeAt(int returnIndex) { + public int resultTypeAt(int returnIndex) { return symbolTable.functionTypeResultTypeAt(typeIndex, returnIndex); } void setTypeEquivalenceClass(int typeEquivalenceClass) { + if (this.typeEquivalenceClass != SymbolTable.NO_EQUIVALENCE_CLASS) { + throw WasmException.create(Failure.UNSPECIFIED_INVALID, "Function at index " + index + " already has an equivalence class."); + } this.typeEquivalenceClass = typeEquivalenceClass; } + public int[] resultTypes() { + return symbolTable.functionTypeResultTypesAsArray(typeIndex); + } + @Override public String toString() { return name(); @@ -142,8 +156,8 @@ public int typeIndex() { return typeIndex; } - public SymbolTable.FunctionType type() { - return symbolTable.typeAt(typeIndex()); + public SymbolTable.ClosedFunctionType closedType() { + return closedFunctionType; } public int typeEquivalenceClass() { @@ -178,7 +192,7 @@ public CallTarget getOrCreateInteropCallAdapter(WasmLanguage language) { CallTarget callAdapter = this.interopCallAdapter; if (callAdapter == null) { // Benign initialization race: The call target will be the same each time. - callAdapter = language.interopCallAdapterFor(type()); + callAdapter = language.interopCallAdapterFor(closedType()); this.interopCallAdapter = callAdapter; } return callAdapter; diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmInstantiator.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmInstantiator.java index 107f9085a63c..c9553ae9fc29 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmInstantiator.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmInstantiator.java @@ -111,7 +111,7 @@ static List recreateLinkActions(WasmModule module) { final EconomicMap importedGlobals = module.importedGlobals(); for (int i = 0; i < module.numGlobals(); i++) { final int globalIndex = i; - final byte globalValueType = module.globalValueType(globalIndex); + final int globalValueType = module.globalValueType(globalIndex); final byte globalMutability = module.globalMutability(globalIndex); if (importedGlobals.containsKey(globalIndex)) { final ImportDescriptor globalDescriptor = importedGlobals.get(globalIndex); @@ -139,7 +139,7 @@ static List recreateLinkActions(WasmModule module) { final int tableIndex = i; final int tableMinSize = module.tableInitialSize(tableIndex); final int tableMaxSize = module.tableMaximumSize(tableIndex); - final byte tableElemType = module.tableElementType(tableIndex); + final int tableElemType = module.tableElementType(tableIndex); final ImportDescriptor tableDescriptor = module.importedTable(tableIndex); if (tableDescriptor != null) { linkActions.add((context, store, instance, imports) -> { @@ -151,9 +151,13 @@ static List recreateLinkActions(WasmModule module) { final ModuleLimits limits = instance.module().limits(); final int maxAllowedSize = WasmMath.minUnsigned(tableMaxSize, limits.tableInstanceSizeLimit()); limits.checkTableInstanceSize(tableMinSize); - final WasmTable wasmTable = new WasmTable(tableMinSize, tableMaxSize, maxAllowedSize, tableElemType); + final WasmTable wasmTable = new WasmTable(tableMinSize, tableMaxSize, maxAllowedSize, tableElemType, module); final int address = store.tables().register(wasmTable); instance.setTableAddress(tableIndex, address); + + final byte[] initBytecode = module.tableInitializerBytecode(tableIndex); + final Object initValue = module.tableInitialValue(tableIndex); + store.linker().resolveTableInitialization(instance, tableIndex, initBytecode, initValue); }); } } @@ -201,7 +205,7 @@ static List recreateLinkActions(WasmModule module) { for (int i = 0; i < module.tagCount(); i++) { final int tagIndex = i; final int typeIndex = module.tagTypeIndex(tagIndex); - final SymbolTable.FunctionType type = module.typeAt(typeIndex); + final SymbolTable.ClosedFunctionType type = module.closedFunctionTypeAt(typeIndex); final ImportDescriptor tagDescriptor = module.importedTag(tagIndex); if (tagDescriptor != null) { linkActions.add((context, store, instance, imports) -> { @@ -225,117 +229,26 @@ static List recreateLinkActions(WasmModule module) { final byte[] bytecode = module.bytecode(); - for (int i = 0; i < module.dataInstanceCount(); i++) { - final int dataIndex = i; - final int dataOffset = module.dataInstanceOffset(dataIndex); - final int encoding = bytecode[dataOffset]; - int effectiveOffset = dataOffset + 1; + for (int i = 0; i < module.elemInstanceCount(); i++) { + final int elemIndex = i; + final int elemOffset = module.elemInstanceOffset(elemIndex); + final int encoding = bytecode[elemOffset]; + final int typeLengthAndMode = bytecode[elemOffset + 1]; + int effectiveOffset = elemOffset + 2; - final int dataMode = encoding & BytecodeBitEncoding.DATA_SEG_MODE_VALUE; - final int dataLength; - switch (encoding & BytecodeBitEncoding.DATA_SEG_LENGTH_MASK) { - case BytecodeBitEncoding.DATA_SEG_LENGTH_U8: - dataLength = BinaryStreamParser.rawPeekU8(bytecode, effectiveOffset); + final int elemMode = typeLengthAndMode & BytecodeBitEncoding.ELEM_SEG_MODE_VALUE; + + switch (typeLengthAndMode & BytecodeBitEncoding.ELEM_SEG_TYPE_MASK) { + case BytecodeBitEncoding.ELEM_SEG_TYPE_I8: effectiveOffset++; break; - case BytecodeBitEncoding.DATA_SEG_LENGTH_U16: - dataLength = BinaryStreamParser.rawPeekU16(bytecode, effectiveOffset); + case BytecodeBitEncoding.ELEM_SEG_TYPE_I16: effectiveOffset += 2; break; - case BytecodeBitEncoding.DATA_SEG_LENGTH_I32: - dataLength = BinaryStreamParser.rawPeekI32(bytecode, effectiveOffset); + case BytecodeBitEncoding.ELEM_SEG_TYPE_I32: effectiveOffset += 4; break; - default: - throw CompilerDirectives.shouldNotReachHere(); } - if (dataMode == SegmentMode.ACTIVE) { - final long value; - switch (encoding & BytecodeBitEncoding.DATA_SEG_VALUE_MASK) { - case BytecodeBitEncoding.DATA_SEG_VALUE_UNDEFINED: - value = -1; - break; - case BytecodeBitEncoding.DATA_SEG_VALUE_U8: - value = BinaryStreamParser.rawPeekU8(bytecode, effectiveOffset); - effectiveOffset++; - break; - case BytecodeBitEncoding.DATA_SEG_VALUE_U16: - value = BinaryStreamParser.rawPeekU16(bytecode, effectiveOffset); - effectiveOffset += 2; - break; - case BytecodeBitEncoding.DATA_SEG_VALUE_U32: - value = BinaryStreamParser.rawPeekU32(bytecode, effectiveOffset); - effectiveOffset += 4; - break; - case BytecodeBitEncoding.DATA_SEG_VALUE_I64: - value = BinaryStreamParser.rawPeekI64(bytecode, effectiveOffset); - effectiveOffset += 8; - break; - default: - throw CompilerDirectives.shouldNotReachHere(); - } - final byte[] dataOffsetBytecode; - final long dataOffsetAddress; - if ((encoding & BytecodeBitEncoding.DATA_SEG_BYTECODE_OR_OFFSET_MASK) == BytecodeBitEncoding.DATA_SEG_BYTECODE && - ((encoding & BytecodeBitEncoding.DATA_SEG_VALUE_MASK) != BytecodeBitEncoding.DATA_SEG_VALUE_UNDEFINED)) { - int dataOffsetBytecodeLength = (int) value; - dataOffsetBytecode = Arrays.copyOfRange(bytecode, effectiveOffset, effectiveOffset + dataOffsetBytecodeLength); - effectiveOffset += dataOffsetBytecodeLength; - dataOffsetAddress = -1; - } else { - dataOffsetBytecode = null; - dataOffsetAddress = value; - } - - final int memoryIndex; - if ((encoding & BytecodeBitEncoding.DATA_SEG_HAS_MEMORY_INDEX_ZERO) != 0) { - memoryIndex = 0; - } else { - final int memoryIndexEncoding = bytecode[effectiveOffset]; - effectiveOffset++; - switch (memoryIndexEncoding & BytecodeBitEncoding.DATA_SEG_MEMORY_INDEX_MASK) { - case BytecodeBitEncoding.DATA_SEG_MEMORY_INDEX_U6: - memoryIndex = memoryIndexEncoding & BytecodeBitEncoding.DATA_SEG_MEMORY_INDEX_VALUE; - break; - case BytecodeBitEncoding.DATA_SEG_MEMORY_INDEX_U8: - memoryIndex = BinaryStreamParser.rawPeekU8(bytecode, effectiveOffset); - effectiveOffset++; - break; - case BytecodeBitEncoding.DATA_SEG_MEMORY_INDEX_U16: - memoryIndex = BinaryStreamParser.rawPeekU16(bytecode, effectiveOffset); - effectiveOffset += 2; - break; - case BytecodeBitEncoding.DATA_SEG_MEMORY_INDEX_I32: - memoryIndex = BinaryStreamParser.rawPeekI32(bytecode, effectiveOffset); - effectiveOffset += 4; - break; - default: - throw CompilerDirectives.shouldNotReachHere(); - } - } - - final int dataBytecodeOffset = effectiveOffset; - linkActions.add((context, store, instance, imports) -> { - store.linker().resolveDataSegment(store, instance, dataIndex, memoryIndex, dataOffsetAddress, dataOffsetBytecode, dataLength, - dataBytecodeOffset, instance.droppedDataInstanceOffset()); - }); - } else { - final int dataBytecodeOffset = effectiveOffset; - linkActions.add((context, store, instance, imports) -> { - store.linker().resolvePassiveDataSegment(store, instance, dataIndex, dataBytecodeOffset); - }); - } - } - - for (int i = 0; i < module.elemInstanceCount(); i++) { - final int elemIndex = i; - final int elemOffset = module.elemInstanceOffset(elemIndex); - final int encoding = bytecode[elemOffset]; - final int typeAndMode = bytecode[elemOffset + 1]; - int effectiveOffset = elemOffset + 2; - - final int elemMode = typeAndMode & BytecodeBitEncoding.ELEM_SEG_MODE_VALUE; - final int elemCount; switch (encoding & BytecodeBitEncoding.ELEM_SEG_COUNT_MASK) { case BytecodeBitEncoding.ELEM_SEG_COUNT_U8: @@ -435,6 +348,108 @@ static List recreateLinkActions(WasmModule module) { } } + for (int i = 0; i < module.dataInstanceCount(); i++) { + final int dataIndex = i; + final int dataOffset = module.dataInstanceOffset(dataIndex); + final int encoding = bytecode[dataOffset]; + int effectiveOffset = dataOffset + 1; + + final int dataMode = encoding & BytecodeBitEncoding.DATA_SEG_MODE_VALUE; + final int dataLength; + switch (encoding & BytecodeBitEncoding.DATA_SEG_LENGTH_MASK) { + case BytecodeBitEncoding.DATA_SEG_LENGTH_U8: + dataLength = BinaryStreamParser.rawPeekU8(bytecode, effectiveOffset); + effectiveOffset++; + break; + case BytecodeBitEncoding.DATA_SEG_LENGTH_U16: + dataLength = BinaryStreamParser.rawPeekU16(bytecode, effectiveOffset); + effectiveOffset += 2; + break; + case BytecodeBitEncoding.DATA_SEG_LENGTH_I32: + dataLength = BinaryStreamParser.rawPeekI32(bytecode, effectiveOffset); + effectiveOffset += 4; + break; + default: + throw CompilerDirectives.shouldNotReachHere(); + } + if (dataMode == SegmentMode.ACTIVE) { + final long value; + switch (encoding & BytecodeBitEncoding.DATA_SEG_VALUE_MASK) { + case BytecodeBitEncoding.DATA_SEG_VALUE_UNDEFINED: + value = -1; + break; + case BytecodeBitEncoding.DATA_SEG_VALUE_U8: + value = BinaryStreamParser.rawPeekU8(bytecode, effectiveOffset); + effectiveOffset++; + break; + case BytecodeBitEncoding.DATA_SEG_VALUE_U16: + value = BinaryStreamParser.rawPeekU16(bytecode, effectiveOffset); + effectiveOffset += 2; + break; + case BytecodeBitEncoding.DATA_SEG_VALUE_U32: + value = BinaryStreamParser.rawPeekU32(bytecode, effectiveOffset); + effectiveOffset += 4; + break; + case BytecodeBitEncoding.DATA_SEG_VALUE_I64: + value = BinaryStreamParser.rawPeekI64(bytecode, effectiveOffset); + effectiveOffset += 8; + break; + default: + throw CompilerDirectives.shouldNotReachHere(); + } + final byte[] dataOffsetBytecode; + final long dataOffsetAddress; + if ((encoding & BytecodeBitEncoding.DATA_SEG_BYTECODE_OR_OFFSET_MASK) == BytecodeBitEncoding.DATA_SEG_BYTECODE && + ((encoding & BytecodeBitEncoding.DATA_SEG_VALUE_MASK) != BytecodeBitEncoding.DATA_SEG_VALUE_UNDEFINED)) { + int dataOffsetBytecodeLength = (int) value; + dataOffsetBytecode = Arrays.copyOfRange(bytecode, effectiveOffset, effectiveOffset + dataOffsetBytecodeLength); + effectiveOffset += dataOffsetBytecodeLength; + dataOffsetAddress = -1; + } else { + dataOffsetBytecode = null; + dataOffsetAddress = value; + } + + final int memoryIndex; + if ((encoding & BytecodeBitEncoding.DATA_SEG_HAS_MEMORY_INDEX_ZERO) != 0) { + memoryIndex = 0; + } else { + final int memoryIndexEncoding = bytecode[effectiveOffset]; + effectiveOffset++; + switch (memoryIndexEncoding & BytecodeBitEncoding.DATA_SEG_MEMORY_INDEX_MASK) { + case BytecodeBitEncoding.DATA_SEG_MEMORY_INDEX_U6: + memoryIndex = memoryIndexEncoding & BytecodeBitEncoding.DATA_SEG_MEMORY_INDEX_VALUE; + break; + case BytecodeBitEncoding.DATA_SEG_MEMORY_INDEX_U8: + memoryIndex = BinaryStreamParser.rawPeekU8(bytecode, effectiveOffset); + effectiveOffset++; + break; + case BytecodeBitEncoding.DATA_SEG_MEMORY_INDEX_U16: + memoryIndex = BinaryStreamParser.rawPeekU16(bytecode, effectiveOffset); + effectiveOffset += 2; + break; + case BytecodeBitEncoding.DATA_SEG_MEMORY_INDEX_I32: + memoryIndex = BinaryStreamParser.rawPeekI32(bytecode, effectiveOffset); + effectiveOffset += 4; + break; + default: + throw CompilerDirectives.shouldNotReachHere(); + } + } + + final int dataBytecodeOffset = effectiveOffset; + linkActions.add((context, store, instance, imports) -> { + store.linker().resolveDataSegment(store, instance, dataIndex, memoryIndex, dataOffsetAddress, dataOffsetBytecode, dataLength, + dataBytecodeOffset, instance.droppedDataInstanceOffset()); + }); + } else { + final int dataBytecodeOffset = effectiveOffset; + linkActions.add((context, store, instance, imports) -> { + store.linker().resolvePassiveDataSegment(store, instance, dataIndex, dataBytecodeOffset); + }); + } + } + return linkActions; } @@ -500,7 +515,7 @@ private static void resolveInstantiatedCodeEntries(WasmStore store, WasmInstance } } - private static FrameDescriptor createFrameDescriptor(byte[] localTypes, int maxStackSize) { + private static FrameDescriptor createFrameDescriptor(int[] localTypes, int maxStackSize) { FrameDescriptor.Builder builder = FrameDescriptor.newBuilder(localTypes.length); builder.addSlots(localTypes.length + maxStackSize, FrameSlotKind.Static); return builder.build(); diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmLanguage.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmLanguage.java index 6e009ea98ce9..514c954372c3 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmLanguage.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmLanguage.java @@ -95,11 +95,11 @@ public final class WasmLanguage extends TruffleLanguage { private final Map builtinModules = new ConcurrentHashMap<>(); - private final Map equivalenceClasses = new ConcurrentHashMap<>(); + private final Map equivalenceClasses = new ConcurrentHashMap<>(); private int nextEquivalenceClass = SymbolTable.FIRST_EQUIVALENCE_CLASS; - private final Map interopCallAdapters = new ConcurrentHashMap<>(); + private final Map interopCallAdapters = new ConcurrentHashMap<>(); - public int equivalenceClassFor(SymbolTable.FunctionType type) { + public int equivalenceClassFor(SymbolTable.ClosedFunctionType type) { CompilerAsserts.neverPartOfCompilation(); Integer equivalenceClass = equivalenceClasses.get(type); if (equivalenceClass == null) { @@ -119,7 +119,7 @@ public int equivalenceClassFor(SymbolTable.FunctionType type) { * Gets or creates the interop call adapter for a function type. Always returns the same call * target for any particular type. */ - public CallTarget interopCallAdapterFor(SymbolTable.FunctionType type) { + public CallTarget interopCallAdapterFor(SymbolTable.ClosedFunctionType type) { CompilerAsserts.neverPartOfCompilation(); CallTarget callAdapter = interopCallAdapters.get(type); if (callAdapter == null) { diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmOptions.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmOptions.java index 3b8100ebea56..54dc2e7f2538 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmOptions.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmOptions.java @@ -139,6 +139,9 @@ public enum ConstantsStorePolicy { @Option(help = "Enable support for exception handling", category = OptionCategory.EXPERT, stability = OptionStability.EXPERIMENTAL, usageSyntax = "false|true") // public static final OptionKey Exceptions = new OptionKey<>(false); + @Option(help = "Enable support for typed function references", category = OptionCategory.EXPERT, stability = OptionStability.EXPERIMENTAL, usageSyntax = "false|true") // + public static final OptionKey TypedFunctionReferences = new OptionKey<>(false); + @Option(help = "In this mode memories and tables are not initialized.", category = OptionCategory.INTERNAL, stability = OptionStability.EXPERIMENTAL, usageSyntax = "false|true") // public static final OptionKey MemoryOverheadMode = new OptionKey<>(false); diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmTable.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmTable.java index dc5d0813bd05..85eafc7d5d83 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmTable.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmTable.java @@ -65,7 +65,13 @@ public final class WasmTable extends EmbedderDataHolder implements TruffleObject /** * @see #elemType() */ - private final byte elemType; + private final int elemType; + + /** + * For resolving {@link #elemType} in {@link #closedElemType()}. Can be {@code null} for tables + * allocated from JS. + */ + private final SymbolTable symbolTable; /** * @see #minSize() @@ -86,7 +92,7 @@ public final class WasmTable extends EmbedderDataHolder implements TruffleObject private Object[] elements; @TruffleBoundary - private WasmTable(int declaredMinSize, int declaredMaxSize, int initialSize, int maxAllowedSize, byte elemType, Object initialValue) { + private WasmTable(int declaredMinSize, int declaredMaxSize, int initialSize, int maxAllowedSize, int elemType, Object initialValue, SymbolTable symbolTable) { assert compareUnsigned(declaredMinSize, initialSize) <= 0; assert compareUnsigned(initialSize, maxAllowedSize) <= 0; assert compareUnsigned(maxAllowedSize, declaredMaxSize) <= 0; @@ -101,14 +107,15 @@ private WasmTable(int declaredMinSize, int declaredMaxSize, int initialSize, int this.elements = new Object[declaredMinSize]; Arrays.fill(this.elements, initialValue); this.elemType = elemType; + this.symbolTable = symbolTable; } - public WasmTable(int declaredMinSize, int declaredMaxSize, int maxAllowedSize, byte elemType) { - this(declaredMinSize, declaredMaxSize, declaredMinSize, maxAllowedSize, elemType, WasmConstant.NULL); + public WasmTable(int declaredMinSize, int declaredMaxSize, int maxAllowedSize, int elemType, SymbolTable symbolTable) { + this(declaredMinSize, declaredMaxSize, declaredMinSize, maxAllowedSize, elemType, WasmConstant.NULL, symbolTable); } - public WasmTable(int declaredMinSize, int declaredMaxSize, int maxAllowedSize, byte elemType, Object initialValue) { - this(declaredMinSize, declaredMaxSize, declaredMinSize, maxAllowedSize, elemType, initialValue); + public WasmTable(int declaredMinSize, int declaredMaxSize, int maxAllowedSize, int elemType, Object initialValue) { + this(declaredMinSize, declaredMaxSize, declaredMinSize, maxAllowedSize, elemType, initialValue, null); } /** @@ -156,12 +163,20 @@ public int declaredMaxSize() { *

* This table can only be imported with an equivalent elem type. * - * @return Either {@link WasmType#FUNCREF_TYPE} or {@link WasmType#EXTERNREF_TYPE}. + * @return Either {@link WasmType#FUNCREF_TYPE}, {@link WasmType#EXTERNREF_TYPE} or some + * concrete reference type. */ - public byte elemType() { + public int elemType() { return elemType; } + /** + * The closed form of the type of the elements in the table. + */ + public SymbolTable.ClosedValueType closedElemType() { + return SymbolTable.closedTypeOf(elemType, symbolTable); + } + /** * The current minimum size of the table. The size can change based on calls to * {@link #grow(int, Object)}. diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmTag.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmTag.java index 1e0ecf0afd03..f760faf8a368 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmTag.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmTag.java @@ -47,13 +47,13 @@ public static final class Attribute { public static final int EXCEPTION = 0; } - private final SymbolTable.FunctionType type; + private final SymbolTable.ClosedFunctionType type; - public WasmTag(SymbolTable.FunctionType type) { + public WasmTag(SymbolTable.ClosedFunctionType type) { this.type = type; } - public SymbolTable.FunctionType type() { + public SymbolTable.ClosedFunctionType type() { return type; } } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmType.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmType.java index 9edc05402c12..d31688f20529 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmType.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmType.java @@ -40,9 +40,6 @@ */ package org.graalvm.wasm; -import org.graalvm.wasm.exception.Failure; -import org.graalvm.wasm.exception.WasmException; - import com.oracle.truffle.api.CompilerAsserts; import com.oracle.truffle.api.CompilerDirectives.CompilationFinal; import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; @@ -52,42 +49,108 @@ import com.oracle.truffle.api.library.ExportLibrary; import com.oracle.truffle.api.library.ExportMessage; +/** + *

+ * Wasm value types are represented as {@code int}s. Predefined types are negative values, while + * user-defined types are non-negative. The second-highest bit is used to signal if a reference type + * is nullable. For the predefined non-reference types (numbers and vectors), that bit is always + * set. + *

+ *

+ * For predefined types, the negative values are the LEB128 decodings of the bytes that represent + * these predefined types in the wasm binary format. E.g., {@code i32} is represented as + * {@code 0x7f} in the binary format, which decodes into {@code -1}, and that is the internal + * representation we use in {@link #I32_TYPE}. + *

+ *

+ * For user-defined types, the non-negative value is the type index which points to the type + * definition in the module's {@link SymbolTable}. + *

+ *

+ * Wasm heap types are represented using the same schema (a union of negative predefined types and + * user-defined non-negative type indices), but without using a special nullability bit. To build a + * reference type out of a heap type, set the nullability bit using {@link #withNullable}: + *

+ * + *
+ *     boolean nullable = ...;
+ *     int heapType = ...;
+ *     int referenceType = WasmType.withNullable(nullable, heapType);
+ * 
+ *

+ * During type checking, it is not enough to compare types by equality, as this does not take into + * account type aliases and subtyping. Instead, use {@link SymbolTable#matchesType(int, int)}. + *

+ *

+ * For an example of how to do case analysis on a Wasm value type, check the source of + * {@link #toString(int)}. NB: The types {@link #TOP} and {@link #BOT} only occur during module + * validation. + *

+ */ @ExportLibrary(InteropLibrary.class) @SuppressWarnings({"unused", "static-method"}) public class WasmType implements TruffleObject { - public static final byte VOID_TYPE = 0x40; - @CompilationFinal(dimensions = 1) public static final byte[] VOID_TYPE_ARRAY = {}; - public static final byte NULL_TYPE = 0x00; - - public static final byte UNKNOWN_TYPE = -1; + private static final int TYPE_INDEX_BITS = 30; + private static final int TYPE_VALUE_MASK = (1 << TYPE_INDEX_BITS) - 1; + private static final int TYPE_NULLABLE_MASK = 1 << TYPE_INDEX_BITS; + private static final int TYPE_PREDEFINED_MASK = 1 << (TYPE_INDEX_BITS + 1); + public static final int MAX_TYPE_INDEX = TYPE_VALUE_MASK; /** * Number Types. */ - public static final byte I32_TYPE = 0x7F; - @CompilationFinal(dimensions = 1) public static final byte[] I32_TYPE_ARRAY = {I32_TYPE}; + public static final int I32_TYPE = -0x01; + @CompilationFinal(dimensions = 1) public static final int[] I32_TYPE_ARRAY = {I32_TYPE}; - public static final byte I64_TYPE = 0x7E; - @CompilationFinal(dimensions = 1) public static final byte[] I64_TYPE_ARRAY = {I64_TYPE}; + public static final int I64_TYPE = -0x02; + @CompilationFinal(dimensions = 1) public static final int[] I64_TYPE_ARRAY = {I64_TYPE}; - public static final byte F32_TYPE = 0x7D; - @CompilationFinal(dimensions = 1) public static final byte[] F32_TYPE_ARRAY = {F32_TYPE}; + public static final int F32_TYPE = -0x03; + @CompilationFinal(dimensions = 1) public static final int[] F32_TYPE_ARRAY = {F32_TYPE}; - public static final byte F64_TYPE = 0x7C; - @CompilationFinal(dimensions = 1) public static final byte[] F64_TYPE_ARRAY = {F64_TYPE}; + public static final int F64_TYPE = -0x04; + @CompilationFinal(dimensions = 1) public static final int[] F64_TYPE_ARRAY = {F64_TYPE}; - public static final byte V128_TYPE = 0x7B; - @CompilationFinal(dimensions = 1) public static final byte[] V128_TYPE_ARRAY = {V128_TYPE}; + /** + * Vector Type. + */ + public static final int V128_TYPE = -0x05; + @CompilationFinal(dimensions = 1) public static final int[] V128_TYPE_ARRAY = {V128_TYPE}; /** * Reference Types. */ - public static final byte FUNCREF_TYPE = 0x70; - @CompilationFinal(dimensions = 1) public static final byte[] FUNCREF_TYPE_ARRAY = {FUNCREF_TYPE}; - public static final byte EXTERNREF_TYPE = 0x6F; - @CompilationFinal(dimensions = 1) public static final byte[] EXTERNREF_TYPE_ARRAY = {EXTERNREF_TYPE}; - public static final byte EXNREF_TYPE = 0x69; - @CompilationFinal(dimensions = 1) public static final byte[] EXNREF_TYPE_ARRAY = {EXNREF_TYPE}; + public static final int FUNC_HEAPTYPE = -0x10; + public static final int EXTERN_HEAPTYPE = -0x11; + public static final int EXN_HEAPTYPE = -0x17; + + public static final int FUNCREF_TYPE = FUNC_HEAPTYPE; + @CompilationFinal(dimensions = 1) public static final int[] FUNCREF_TYPE_ARRAY = {FUNCREF_TYPE}; + + public static final int EXTERNREF_TYPE = EXTERN_HEAPTYPE; + @CompilationFinal(dimensions = 1) public static final int[] EXTERNREF_TYPE_ARRAY = {EXTERNREF_TYPE}; + + public static final int EXNREF_TYPE = EXN_HEAPTYPE; + @CompilationFinal(dimensions = 1) public static final int[] EXNREF_TYPE_ARRAY = {EXNREF_TYPE}; + + // Implementation-specific Types. + /** + * The common supertype of all types, the universal type. + */ + public static final int TOP = -0x7e; + /** + * The common subtype of all types, the impossible type. + */ + public static final int BOT = -0x7f; + + /** + * Bytes used in the binary encoding of types. + */ + public static final byte REF_TYPE_HEADER = -0x1c; + public static final byte REF_NULL_TYPE_HEADER = -0x1d; + // -0x40 is what the void block type byte 0x40 looks like when read as a signed LEB128 value. + public static final byte VOID_BLOCK_TYPE = -0x40; + @CompilationFinal(dimensions = 1) public static final int[] VOID_TYPE_ARRAY = {}; public static final WasmType VOID = new WasmType("void"); public static final WasmType NULL = new WasmType("null"); @@ -103,45 +166,111 @@ public class WasmType implements TruffleObject { public static String toString(int valueType) { CompilerAsserts.neverPartOfCompilation(); - switch (valueType) { - case I32_TYPE: - return "i32"; - case I64_TYPE: - return "i64"; - case F32_TYPE: - return "f32"; - case F64_TYPE: - return "f64"; - case V128_TYPE: - return "v128"; - case VOID_TYPE: - return "void"; - case FUNCREF_TYPE: - return "funcref"; - case EXTERNREF_TYPE: - return "externref"; - case EXNREF_TYPE: - return "exnref"; - default: - throw WasmException.create(Failure.UNSPECIFIED_INTERNAL, null, "Unknown value type: 0x" + Integer.toHexString(valueType)); - } + return switch (valueType) { + case I32_TYPE -> "i32"; + case I64_TYPE -> "i64"; + case F32_TYPE -> "f32"; + case F64_TYPE -> "f64"; + case V128_TYPE -> "v128"; + case TOP -> "top"; + case BOT -> "bot"; + default -> { + assert WasmType.isReferenceType(valueType); + boolean nullable = WasmType.isNullable(valueType); + yield switch (WasmType.getAbstractHeapType(valueType)) { + case FUNC_HEAPTYPE -> nullable ? "funcref" : "(ref func)"; + case EXTERN_HEAPTYPE -> nullable ? "externref" : "(ref extern)"; + case EXN_HEAPTYPE -> nullable ? "exnref" : "(ref exn)"; + default -> { + assert WasmType.isConcreteReferenceType(valueType); + StringBuilder sb = new StringBuilder(16); + sb.append("(ref "); + if (nullable) { + sb.append("null "); + } + sb.append(getTypeIndex(valueType)); + sb.append(")"); + yield sb.toString(); + } + }; + } + }; + } + + public static boolean isNumberType(int type) { + return type == I32_TYPE || type == I64_TYPE || type == F32_TYPE || type == F64_TYPE || type == BOT; + } + + public static boolean isVectorType(int type) { + return type == V128_TYPE || type == BOT; } - public static boolean isNumberType(byte type) { - return type == I32_TYPE || type == I64_TYPE || type == F32_TYPE || type == F64_TYPE || type == UNKNOWN_TYPE; + public static boolean isReferenceType(int type) { + return isConcreteReferenceType(type) || withNullable(true, type) == FUNC_HEAPTYPE || withNullable(true, type) == EXTERN_HEAPTYPE || withNullable(true, type) == EXN_HEAPTYPE || type == BOT; } - public static boolean isVectorType(byte type) { - return type == V128_TYPE || type == UNKNOWN_TYPE; + /** + * Indicates whether this is a user-defined reference type. + */ + public static boolean isConcreteReferenceType(int type) { + return type >= 0 || type == BOT; + } + + /** + * Returns the type index of this user-defined reference type. This must be used together with + * the appropriate {@link SymbolTable} to be able to expand the type's definition. + */ + public static int getTypeIndex(int type) { + assert isConcreteReferenceType(type); + return type & TYPE_VALUE_MASK; } - public static boolean isReferenceType(byte type) { - return type == FUNCREF_TYPE || type == EXTERNREF_TYPE || type == EXNREF_TYPE || type == UNKNOWN_TYPE; + /** + * Returns the "payload" of this reference type, which can be matched against the predefined + * abstract heap types (such as {@link #FUNC_HEAPTYPE} or {@link #EXTERN_HEAPTYPE}) in a switch + * statement. + */ + public static int getAbstractHeapType(int type) { + assert isReferenceType(type); + return withNullable(true, type); + } + + /** + * Indicates whether this value types admits the value {@link WasmConstant#NULL}. Can only be + * called on reference types. + */ + public static boolean isNullable(int type) { + assert isReferenceType(type); + return (type & TYPE_NULLABLE_MASK) != 0; + } + + /** + * Updates the nullability bit of this reference type. Can also be used to create a reference + * type from a heap type. + * + * @param nullable whether the resulting reference type should be nullable + * @param type a reference type or a heap type (one of the {@code *_HEAPTYPE} constants or a + * type index) + */ + public static int withNullable(boolean nullable, int type) { + if (type == BOT) { + return BOT; + } + return nullable ? type | TYPE_NULLABLE_MASK : type & ~TYPE_NULLABLE_MASK; + } + + /** + * Indicates whether this type has a default value (this is the case for all value types except + * for non-nullable reference types). Locals of such types do not have to be initialized prior + * to first access. + */ + public static boolean hasDefaultValue(int type) { + return !(isReferenceType(type) && !isNullable(type)); } - public static int getCommonValueType(byte[] types) { + public static int getCommonValueType(int[] types) { int type = 0; - for (byte resultType : types) { + for (int resultType : types) { type |= WasmType.isNumberType(resultType) ? NUM_COMMON_TYPE : 0; type |= WasmType.isVectorType(resultType) || WasmType.isReferenceType(resultType) ? OBJ_COMMON_TYPE : 0; } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ExecuteHostFunctionNode.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ExecuteHostFunctionNode.java index 3e9440eac6af..cb6277d78f4e 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ExecuteHostFunctionNode.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ExecuteHostFunctionNode.java @@ -40,6 +40,7 @@ */ package org.graalvm.wasm.api; +import org.graalvm.wasm.SymbolTable; import org.graalvm.wasm.WasmArguments; import org.graalvm.wasm.WasmConstant; import org.graalvm.wasm.WasmContext; @@ -108,7 +109,7 @@ public Object execute(VirtualFrame frame) { if (resultCount == 0) { return WasmConstant.VOID; } else if (resultCount == 1) { - byte resultType = module.symbolTable().functionTypeResultTypeAt(functionTypeIndex, 0); + int resultType = module.symbolTable().functionTypeResultTypeAt(functionTypeIndex, 0); return convertResult(result, resultType); } else { pushMultiValueResult(result, resultCount); @@ -125,16 +126,22 @@ public Object execute(VirtualFrame frame) { * of the correct boxed type because they're already converted on the JS side, so we only need * to unbox and can forego InteropLibrary. */ - private Object convertResult(Object result, byte resultType) throws UnsupportedMessageException { + private Object convertResult(Object result, int resultType) throws UnsupportedMessageException { CompilerAsserts.partialEvaluationConstant(resultType); + SymbolTable.ClosedValueType closedResultType = module.closedTypeOf(resultType); + CompilerAsserts.partialEvaluationConstant(closedResultType); return switch (resultType) { case WasmType.I32_TYPE -> asInt(result); case WasmType.I64_TYPE -> asLong(result); case WasmType.F32_TYPE -> asFloat(result); case WasmType.F64_TYPE -> asDouble(result); - case WasmType.V128_TYPE, WasmType.FUNCREF_TYPE, WasmType.EXTERNREF_TYPE, WasmType.EXNREF_TYPE -> result; default -> { - throw WasmException.format(Failure.UNSPECIFIED_TRAP, this, "Unknown result type: %d", resultType); + assert WasmType.isVectorType(resultType) || WasmType.isReferenceType(resultType); + if (!closedResultType.matchesValue(result)) { + errorBranch.enter(); + throw WasmException.create(Failure.TYPE_MISMATCH); + } + yield result; } }; } @@ -157,26 +164,24 @@ private void pushMultiValueResult(Object result, int resultCount) { final long[] primitiveMultiValueStack = multiValueStack.primitiveStack(); final Object[] objectMultiValueStack = multiValueStack.objectStack(); for (int i = 0; i < resultCount; i++) { - byte resultType = module.symbolTable().functionTypeResultTypeAt(functionTypeIndex, i); + int resultType = module.symbolTable().functionTypeResultTypeAt(functionTypeIndex, i); CompilerAsserts.partialEvaluationConstant(resultType); + SymbolTable.ClosedValueType closedResultType = module.closedTypeOf(resultType); + CompilerAsserts.partialEvaluationConstant(closedResultType); Object value = arrayInterop.readArrayElement(result, i); switch (resultType) { case WasmType.I32_TYPE -> primitiveMultiValueStack[i] = asInt(value); case WasmType.I64_TYPE -> primitiveMultiValueStack[i] = asLong(value); case WasmType.F32_TYPE -> primitiveMultiValueStack[i] = Float.floatToRawIntBits(asFloat(value)); case WasmType.F64_TYPE -> primitiveMultiValueStack[i] = Double.doubleToRawLongBits(asDouble(value)); - case WasmType.V128_TYPE -> { - if (!(value instanceof Vector128)) { + default -> { + assert WasmType.isVectorType(resultType) || WasmType.isReferenceType(resultType); + if (!closedResultType.matchesValue(value)) { errorBranch.enter(); throw WasmException.create(Failure.INVALID_TYPE_IN_MULTI_VALUE); } objectMultiValueStack[i] = value; } - case WasmType.FUNCREF_TYPE, WasmType.EXTERNREF_TYPE, WasmType.EXNREF_TYPE -> objectMultiValueStack[i] = value; - default -> { - errorBranch.enter(); - throw WasmException.format(Failure.UNSPECIFIED_TRAP, this, "Unknown result type: %d", resultType); - } } } } catch (UnsupportedMessageException | InvalidArrayIndexException e) { diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/FuncType.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/FuncType.java index 560eedc8eb81..2cef0d2abe41 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/FuncType.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/FuncType.java @@ -84,18 +84,18 @@ private static ValueType[] parseTypeString(String typesString, int start, int en } } - public static FuncType fromFunctionType(SymbolTable.FunctionType functionType) { - final byte[] paramTypes = functionType.paramTypes(); - final byte[] resultTypes = functionType.resultTypes(); + public static FuncType fromClosedFunctionType(SymbolTable.ClosedFunctionType functionType) { + final SymbolTable.ClosedValueType[] paramTypes = functionType.paramTypes(); + final SymbolTable.ClosedValueType[] resultTypes = functionType.resultTypes(); final ValueType[] params = new ValueType[paramTypes.length]; final ValueType[] results = new ValueType[resultTypes.length]; for (int i = 0; i < paramTypes.length; i++) { - params[i] = ValueType.fromByteValue(paramTypes[i]); + params[i] = ValueType.fromClosedValueType(paramTypes[i]); } for (int i = 0; i < resultTypes.length; i++) { - results[i] = ValueType.fromByteValue(resultTypes[i]); + results[i] = ValueType.fromClosedValueType(resultTypes[i]); } return new FuncType(params, results); } @@ -116,17 +116,17 @@ public int resultCount() { return results.length; } - public SymbolTable.FunctionType toFunctionType() { - final byte[] paramTypes = new byte[params.length]; - final byte[] resultTypes = new byte[results.length]; + public SymbolTable.ClosedFunctionType toClosedFunctionType() { + var paramTypes = new SymbolTable.ClosedValueType[params.length]; + var resultTypes = new SymbolTable.ClosedValueType[results.length]; for (int i = 0; i < paramTypes.length; i++) { - paramTypes[i] = params[i].byteValue(); + paramTypes[i] = params[i].asClosedValueType(); } for (int i = 0; i < resultTypes.length; i++) { - resultTypes[i] = results[i].byteValue(); + resultTypes[i] = results[i].asClosedValueType(); } - return SymbolTable.FunctionType.create(paramTypes, resultTypes); + return new SymbolTable.ClosedFunctionType(paramTypes, resultTypes); } public String toString(StringBuilder b) { diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/InteropCallAdapterNode.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/InteropCallAdapterNode.java index c74567a05517..82aa7f8f9865 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/InteropCallAdapterNode.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/InteropCallAdapterNode.java @@ -42,14 +42,10 @@ import org.graalvm.wasm.SymbolTable; import org.graalvm.wasm.WasmArguments; -import org.graalvm.wasm.WasmConstant; import org.graalvm.wasm.WasmContext; import org.graalvm.wasm.WasmFunctionInstance; import org.graalvm.wasm.WasmLanguage; import org.graalvm.wasm.WasmType; -import org.graalvm.wasm.exception.Failure; -import org.graalvm.wasm.exception.WasmException; -import org.graalvm.wasm.exception.WasmRuntimeException; import org.graalvm.wasm.nodes.WasmIndirectCallNode; import com.oracle.truffle.api.CallTarget; @@ -76,11 +72,11 @@ public final class InteropCallAdapterNode extends RootNode { private static final int MAX_UNROLL = 32; - private final SymbolTable.FunctionType functionType; + private final SymbolTable.ClosedFunctionType functionType; private final BranchProfile errorBranch = BranchProfile.create(); @Child private WasmIndirectCallNode callNode; - public InteropCallAdapterNode(WasmLanguage language, SymbolTable.FunctionType functionType) { + public InteropCallAdapterNode(WasmLanguage language, SymbolTable.ClosedFunctionType functionType) { super(language); this.functionType = functionType; this.callNode = WasmIndirectCallNode.create(); @@ -111,7 +107,7 @@ public Object execute(VirtualFrame frame) { } private Object[] validateArguments(Object[] arguments, int offset) throws ArityException, UnsupportedTypeException { - final byte[] paramTypes = functionType.paramTypes(); + final SymbolTable.ClosedValueType[] paramTypes = functionType.paramTypes(); final int paramCount = paramTypes.length; CompilerAsserts.partialEvaluationConstant(paramCount); if (arguments.length - offset != paramCount) { @@ -128,64 +124,25 @@ private Object[] validateArguments(Object[] arguments, int offset) throws ArityE } @ExplodeLoop - private static void validateArgumentsUnroll(Object[] arguments, int offset, byte[] paramTypes, int paramCount) throws UnsupportedTypeException { + private static void validateArgumentsUnroll(Object[] arguments, int offset, SymbolTable.ClosedValueType[] paramTypes, int paramCount) throws UnsupportedTypeException { for (int i = 0; i < paramCount; i++) { validateArgument(arguments, offset, paramTypes, i); } } - private static void validateArgument(Object[] arguments, int offset, byte[] paramTypes, int i) throws UnsupportedTypeException { - byte paramType = paramTypes[i]; + private static void validateArgument(Object[] arguments, int offset, SymbolTable.ClosedValueType[] paramTypes, int i) throws UnsupportedTypeException { + SymbolTable.ClosedValueType paramType = paramTypes[i]; Object value = arguments[i + offset]; - switch (paramType) { - case WasmType.I32_TYPE -> { - if (value instanceof Integer) { - return; - } - } - case WasmType.I64_TYPE -> { - if (value instanceof Long) { - return; - } - } - case WasmType.F32_TYPE -> { - if (value instanceof Float) { - return; - } - } - case WasmType.F64_TYPE -> { - if (value instanceof Double) { - return; - } - } - case WasmType.V128_TYPE -> { - if (value instanceof Vector128) { - return; - } - } - case WasmType.FUNCREF_TYPE -> { - if (value instanceof WasmFunctionInstance || value == WasmConstant.NULL) { - return; - } - } - case WasmType.EXTERNREF_TYPE -> { - return; - } - case WasmType.EXNREF_TYPE -> { - if (value instanceof WasmRuntimeException || value == WasmConstant.NULL) { - return; - } - } - default -> throw WasmException.create(Failure.UNKNOWN_TYPE); + if (!paramType.matchesValue(value)) { + throw UnsupportedTypeException.create(arguments); } - throw UnsupportedTypeException.create(arguments); } private Object multiValueStackAsArray(WasmLanguage language) { final var multiValueStack = language.multiValueStack(); final long[] primitiveMultiValueStack = multiValueStack.primitiveStack(); final Object[] objectMultiValueStack = multiValueStack.objectStack(); - final byte[] resultTypes = functionType.resultTypes(); + final SymbolTable.ClosedValueType[] resultTypes = functionType.resultTypes(); final int resultCount = resultTypes.length; assert primitiveMultiValueStack.length >= resultCount; assert objectMultiValueStack.length >= resultCount; @@ -202,29 +159,33 @@ private Object multiValueStackAsArray(WasmLanguage language) { } @ExplodeLoop - private static void popMultiValueResultUnroll(Object[] values, long[] primitiveMultiValueStack, Object[] objectMultiValueStack, byte[] resultTypes, int resultCount) { + private static void popMultiValueResultUnroll(Object[] values, long[] primitiveMultiValueStack, Object[] objectMultiValueStack, SymbolTable.ClosedValueType[] resultTypes, int resultCount) { for (int i = 0; i < resultCount; i++) { values[i] = popMultiValueResult(primitiveMultiValueStack, objectMultiValueStack, resultTypes, i); } } - private static Object popMultiValueResult(long[] primitiveMultiValueStack, Object[] objectMultiValueStack, byte[] resultTypes, int i) { - final byte resultType = resultTypes[i]; - return switch (resultType) { - case WasmType.I32_TYPE -> (int) primitiveMultiValueStack[i]; - case WasmType.I64_TYPE -> primitiveMultiValueStack[i]; - case WasmType.F32_TYPE -> Float.intBitsToFloat((int) primitiveMultiValueStack[i]); - case WasmType.F64_TYPE -> Double.longBitsToDouble(primitiveMultiValueStack[i]); - case WasmType.V128_TYPE, WasmType.FUNCREF_TYPE, WasmType.EXTERNREF_TYPE, WasmType.EXNREF_TYPE -> { + private static Object popMultiValueResult(long[] primitiveMultiValueStack, Object[] objectMultiValueStack, SymbolTable.ClosedValueType[] resultTypes, int i) { + final SymbolTable.ClosedValueType resultType = resultTypes[i]; + return switch (resultType.kind()) { + case Number -> { + SymbolTable.NumberType numberType = (SymbolTable.NumberType) resultType; + yield switch (numberType.value()) { + case WasmType.I32_TYPE -> (int) primitiveMultiValueStack[i]; + case WasmType.I64_TYPE -> primitiveMultiValueStack[i]; + case WasmType.F32_TYPE -> Float.intBitsToFloat((int) primitiveMultiValueStack[i]); + case WasmType.F64_TYPE -> Double.longBitsToDouble(primitiveMultiValueStack[i]); + default -> throw CompilerDirectives.shouldNotReachHere(); + }; + } + case Vector, Reference -> { Object obj = objectMultiValueStack[i]; objectMultiValueStack[i] = null; yield obj; } - default -> throw WasmException.create(Failure.UNSPECIFIED_INTERNAL); }; } - // TODO: Do we need the 3 overrides below? @Override public String getName() { return "wasm-function-interop:" + functionType; diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/TableKind.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/TableKind.java index 3e565932388b..91ab0519f9d5 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/TableKind.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/TableKind.java @@ -44,24 +44,22 @@ public enum TableKind { externref(WasmType.EXTERNREF_TYPE), - anyfunc(WasmType.FUNCREF_TYPE), - exnref(WasmType.EXNREF_TYPE); + anyfunc(WasmType.FUNCREF_TYPE); - private final byte byteValue; + private final int value; - TableKind(byte byteValue) { - this.byteValue = byteValue; + TableKind(int value) { + this.value = value; } - public byte byteValue() { - return byteValue; + public int value() { + return value; } - public static String toString(byte byteValue) { - return switch (byteValue) { + public static String toString(int value) { + return switch (value) { case WasmType.EXTERNREF_TYPE -> "externref"; case WasmType.FUNCREF_TYPE -> "anyfunc"; - case WasmType.EXNREF_TYPE -> "exnref"; default -> ""; }; } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ValueType.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ValueType.java index 61c707315525..8bccec71a4ca 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ValueType.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ValueType.java @@ -40,11 +40,12 @@ */ package org.graalvm.wasm.api; +import org.graalvm.wasm.SymbolTable; import org.graalvm.wasm.WasmType; -import org.graalvm.wasm.exception.Failure; -import org.graalvm.wasm.exception.WasmException; import com.oracle.truffle.api.CompilerAsserts; +import org.graalvm.wasm.exception.Failure; +import org.graalvm.wasm.exception.WasmException; public enum ValueType { i32(WasmType.I32_TYPE), @@ -56,13 +57,13 @@ public enum ValueType { externref(WasmType.EXTERNREF_TYPE), exnref(WasmType.EXNREF_TYPE); - private final byte byteValue; + private final int value; - ValueType(byte byteValue) { - this.byteValue = byteValue; + ValueType(int value) { + this.value = value; } - public static ValueType fromByteValue(byte value) { + public static ValueType fromValue(int value) { CompilerAsserts.neverPartOfCompilation(); return switch (value) { case WasmType.I32_TYPE -> i32; @@ -70,16 +71,23 @@ public static ValueType fromByteValue(byte value) { case WasmType.F32_TYPE -> f32; case WasmType.F64_TYPE -> f64; case WasmType.V128_TYPE -> v128; - case WasmType.FUNCREF_TYPE -> anyfunc; - case WasmType.EXTERNREF_TYPE -> externref; - case WasmType.EXNREF_TYPE -> exnref; - default -> - throw WasmException.create(Failure.UNSPECIFIED_INTERNAL, null, "Unknown value type: 0x" + Integer.toHexString(value)); + default -> { + assert WasmType.isReferenceType(value); + yield switch (WasmType.getAbstractHeapType(value)) { + case WasmType.FUNC_HEAPTYPE -> anyfunc; + case WasmType.EXTERN_HEAPTYPE -> externref; + case WasmType.EXN_HEAPTYPE -> exnref; + default -> { + assert WasmType.isConcreteReferenceType(value); + yield anyfunc; + } + }; + } }; } - public byte byteValue() { - return byteValue; + public int value() { + return value; } public static boolean isNumberType(ValueType valueType) { @@ -93,4 +101,38 @@ public static boolean isVectorType(ValueType valueType) { public static boolean isReferenceType(ValueType valueType) { return valueType == anyfunc || valueType == externref || valueType == exnref; } + + public SymbolTable.ClosedValueType asClosedValueType() { + return switch (this) { + case i32 -> SymbolTable.NumberType.I32; + case i64 -> SymbolTable.NumberType.I64; + case f32 -> SymbolTable.NumberType.F32; + case f64 -> SymbolTable.NumberType.F64; + case v128 -> SymbolTable.VectorType.V128; + case anyfunc -> SymbolTable.ClosedReferenceType.FUNCREF; + case externref -> SymbolTable.ClosedReferenceType.EXTERNREF; + case exnref -> SymbolTable.ClosedReferenceType.EXNREF; + }; + } + + public static ValueType fromClosedValueType(SymbolTable.ClosedValueType closedValueType) { + return switch (closedValueType.kind()) { + case Number -> fromValue(((SymbolTable.NumberType) closedValueType).value()); + case Vector -> fromValue(((SymbolTable.VectorType) closedValueType).value()); + case Reference -> { + SymbolTable.ClosedReferenceType referenceType = (SymbolTable.ClosedReferenceType) closedValueType; + yield switch (referenceType.heapType().kind()) { + case Abstract -> { + SymbolTable.AbstractHeapType abstractHeapType = (SymbolTable.AbstractHeapType) referenceType.heapType(); + yield switch (abstractHeapType.value()) { + case WasmType.FUNC_HEAPTYPE -> anyfunc; + case WasmType.EXTERNREF_TYPE -> externref; + default -> throw WasmException.create(Failure.UNSPECIFIED_INTERNAL, null, "Unknown value type: 0x" + Integer.toHexString(abstractHeapType.value())); + }; + } + case Function -> anyfunc; + }; + } + }; + } } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/Vector128Shape.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/Vector128Shape.java index 388bad7a939a..e9ff2fe3c71d 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/Vector128Shape.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/Vector128Shape.java @@ -53,15 +53,15 @@ public enum Vector128Shape { F32X4(WasmType.F32_TYPE, 4), F64X2(WasmType.F64_TYPE, 2); - private final byte unpackedType; + private final int unpackedType; private final int dimension; - Vector128Shape(byte unpackedType, int dimension) { + Vector128Shape(int unpackedType, int dimension) { this.unpackedType = unpackedType; this.dimension = dimension; } - public byte getUnpackedType() { + public int getUnpackedType() { return unpackedType; } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/WebAssembly.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/WebAssembly.java index b6d6916d920c..b7b4b49fccab 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/WebAssembly.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/WebAssembly.java @@ -53,6 +53,7 @@ import org.graalvm.polyglot.io.ByteSequence; import org.graalvm.wasm.EmbedderDataHolder; import org.graalvm.wasm.ImportDescriptor; +import org.graalvm.wasm.SymbolTable; import org.graalvm.wasm.WasmConstant; import org.graalvm.wasm.WasmContext; import org.graalvm.wasm.WasmCustomSection; @@ -112,7 +113,7 @@ public WebAssembly(WasmContext currentContext) { addMember("mem_set_wait_callback", new Executable(WebAssembly::memSetWaitCallback)); addMember("global_alloc", new Executable(this::globalAlloc)); - addMember("global_read", new Executable(WebAssembly::globalRead)); + addMember("global_read", new Executable(this::globalRead)); addMember("global_write", new Executable(this::globalWrite)); addMember("tag_alloc", new Executable(WebAssembly::tagAlloc)); @@ -267,7 +268,7 @@ public static Sequence moduleExports(WasmModule module) } else if (f != null) { list.add(new ModuleExportDescriptor(name, ImportExportKind.function.name(), WebAssembly.functionTypeToString(f))); } else if (globalIndex != null) { - String valueType = ValueType.fromByteValue(module.globalValueType(globalIndex)).toString(); + String valueType = ValueType.fromValue(module.globalValueType(globalIndex)).toString(); String mutability = module.isGlobalMutable(globalIndex) ? "mut" : "con"; list.add(new ModuleExportDescriptor(name, ImportExportKind.global.name(), valueType + " " + mutability)); } else if (tagIndex != null) { @@ -319,7 +320,7 @@ public static List moduleImportsAsList(WasmModule module break; case ImportIdentifier.GLOBAL: final Integer globalIndex = importedGlobalDescriptors.get(descriptor); - String valueType = ValueType.fromByteValue(module.globalValueType(globalIndex)).toString(); + String valueType = ValueType.fromValue(module.globalValueType(globalIndex)).toString(); list.add(new ModuleImportDescriptor(descriptor.moduleName(), descriptor.memberName(), ImportExportKind.global.name(), valueType)); break; case ImportIdentifier.TAG: @@ -442,14 +443,14 @@ public WasmTable tableAlloc(int initial, int maximum, TableKind elemKind, Object if (Integer.compareUnsigned(initial, JS_LIMITS.tableInstanceSizeLimit()) > 0) { throw new WasmJsApiException(WasmJsApiException.Kind.RangeError, "Min table size exceeds implementation limit"); } - if (elemKind != TableKind.externref && elemKind != TableKind.anyfunc && elemKind != TableKind.exnref) { + if (elemKind != TableKind.externref && elemKind != TableKind.anyfunc) { throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Element type must be a reftype"); } - if (!refTypes && (elemKind == TableKind.externref || elemKind == TableKind.exnref)) { + if (!refTypes && elemKind == TableKind.externref) { throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Element type must be anyfunc. Enable wasm.BulkMemoryAndRefTypes to support other reference types"); } final int maxAllowedSize = minUnsigned(maximum, JS_LIMITS.tableInstanceSizeLimit()); - return new WasmTable(initial, maximum, maxAllowedSize, elemKind.byteValue(), initialValue); + return new WasmTable(initial, maximum, maxAllowedSize, elemKind.value(), initialValue); } private static Object tableGrow(Object[] args) { @@ -508,23 +509,12 @@ private Object tableWrite(Object[] args) { } public Object tableWrite(WasmTable table, int index, Object element) { - final Object elem; - if (element instanceof WasmFunctionInstance) { - elem = element; - } else if (element == WasmConstant.NULL) { - elem = WasmConstant.NULL; - } else { - if (!currentContext.getContextOptions().supportBulkMemoryAndRefTypes()) { - throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Invalid table element"); - } - if (table.elemType() == WasmType.FUNCREF_TYPE) { - throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Invalid table element"); - } - elem = element; + if (!table.closedElemType().matchesValue(element)) { + throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Invalid table element"); } try { - table.set(index, elem); + table.set(index, element); } catch (ArrayIndexOutOfBoundsException e) { throw new WasmJsApiException(WasmJsApiException.Kind.RangeError, "Table index out of bounds: " + e.getMessage()); } @@ -571,7 +561,7 @@ public static String functionTypeToString(WasmFunction f) { if (i != 0) { typeInfo.append(' '); } - typeInfo.append(ValueType.fromByteValue(f.paramTypeAt(i))); + typeInfo.append(ValueType.fromValue(f.paramTypeAt(i))); } typeInfo.append(')'); @@ -580,7 +570,7 @@ public static String functionTypeToString(WasmFunction f) { if (i != 0) { typeInfo.append(' '); } - typeInfo.append(ValueType.fromByteValue(f.resultTypeAt(i))); + typeInfo.append(ValueType.fromValue(f.resultTypeAt(i))); } return typeInfo.toString(); } @@ -591,7 +581,7 @@ private static String tagTypeToString(WasmModule module, int tagIndex) { assert attribute == WasmTag.Attribute.EXCEPTION; final int typeIndex = module.tagTypeIndex(tagIndex); - return FuncType.fromFunctionType(module.typeAt(typeIndex)).toString(); + return FuncType.fromClosedFunctionType(module.closedFunctionTypeAt(typeIndex)).toString(); } private static Object memAlloc(Object[] args) { @@ -663,16 +653,6 @@ public static long memGrow(WasmMemory memory, int delta) { private static Object memSetGrowCallback(Object[] args) { checkArgumentCount(args, 1); InteropLibrary lib = InteropLibrary.getUncached(); - if (args.length > 1) { - // TODO: drop this branch after JS adopts the single-argument version - if (!(args[0] instanceof WasmMemory)) { - throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "First argument must be executable"); - } - if (!lib.isExecutable(args[1])) { - throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Second argument must be executable"); - } - return memSetGrowCallback(args[1]); - } if (!lib.isExecutable(args[0])) { throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Argument must be executable"); } @@ -701,16 +681,6 @@ public static void invokeMemGrowCallback(WasmMemory memory) { private static Object memSetNotifyCallback(Object[] args) { checkArgumentCount(args, 1); InteropLibrary lib = InteropLibrary.getUncached(); - if (args.length > 1) { - // TODO: drop this branch after JS adopts the single-argument version - if (!(args[0] instanceof WasmMemory)) { - throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "First argument must be executable"); - } - if (!lib.isExecutable(args[1])) { - throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Second argument must be executable"); - } - return memSetNotifyCallback(args[1]); - } if (!lib.isExecutable(args[0])) { throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Argument must be executable"); } @@ -740,16 +710,6 @@ public static int invokeMemNotifyCallback(Node node, WasmMemory memory, long add private static Object memSetWaitCallback(Object[] args) { checkArgumentCount(args, 1); InteropLibrary lib = InteropLibrary.getUncached(); - if (args.length > 1) { - // TODO: drop this branch after JS adopts the single-argument version - if (!(args[0] instanceof WasmMemory)) { - throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "First argument must be executable"); - } - if (!lib.isExecutable(args[1])) { - throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Second argument must be executable"); - } - return memSetWaitCallback(args[1]); - } if (!lib.isExecutable(args[0])) { throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Argument must be executable"); } @@ -799,79 +759,71 @@ private Object globalAlloc(Object[] args) { String valueTypeString = lib.asString(args[0]); valueType = ValueType.valueOf(valueTypeString); } catch (UnsupportedMessageException e) { - throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "First argument (value type) must be convertible to String"); + throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "First argument (value type) must be convertible to String."); } catch (IllegalArgumentException ex) { - throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Invalid value type"); + throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Invalid value type."); } final boolean mutable; try { mutable = lib.asBoolean(args[1]); } catch (UnsupportedMessageException e) { - throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "First argument (mutable) must be convertible to boolean"); + throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "First argument (mutable) must be convertible to boolean."); } return globalAlloc(valueType, mutable, args[2]); } public WasmGlobal globalAlloc(ValueType valueType, boolean mutable, Object value) { - InteropLibrary valueInterop = InteropLibrary.getUncached(value); - try { - switch (valueType) { - case i32: - return new WasmGlobal(valueType, mutable, valueInterop.asInt(value)); - case i64: - return new WasmGlobal(valueType, mutable, valueInterop.asLong(value)); - case f32: - return new WasmGlobal(valueType, mutable, Float.floatToRawIntBits(valueInterop.asFloat(value))); - case f64: - return new WasmGlobal(valueType, mutable, Double.doubleToRawLongBits(valueInterop.asDouble(value))); - case anyfunc: - if (!refTypes || !(value == WasmConstant.NULL || value instanceof WasmFunctionInstance)) { - throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Invalid value type"); - } - return new WasmGlobal(valueType, mutable, value); - case externref: - if (!refTypes) { - throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Invalid value type"); - } - return new WasmGlobal(valueType, mutable, value); - default: - throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Invalid value type"); + if (!valueType.asClosedValueType().matchesValue(value)) { + throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Global type %s, value: %s", valueType, value); + } + return switch (valueType) { + case i32 -> WasmGlobal.alloc32(valueType, mutable, (int) value); + case i64 -> WasmGlobal.alloc64(valueType, mutable, (long) value); + case f32 -> WasmGlobal.alloc32(valueType, mutable, Float.floatToRawIntBits((float) value)); + case f64 -> WasmGlobal.alloc64(valueType, mutable, Double.doubleToRawLongBits((double) value)); + case v128 -> throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, WasmJsApiException.V128_VALUE_ACCESS); + case anyfunc, externref -> { + if (!refTypes) { + throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Invalid value type. Reference types are not enabled."); + } + yield WasmGlobal.allocRef(valueType, mutable, value); } - } catch (UnsupportedMessageException ex) { - throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Cannot convert value to the specified value type"); - } + case exnref -> throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, WasmJsApiException.EXNREF_VALUE_ACCESS); + }; } - private static Object globalRead(Object[] args) { + private Object globalRead(Object[] args) { checkArgumentCount(args, 1); if (!(args[0] instanceof WasmGlobal global)) { - throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "First argument must be wasm global"); + throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "First argument must be wasm global."); } return globalRead(global); } - public static Object globalRead(WasmGlobal global) { - switch (global.getValueType()) { - case i32: - return global.loadAsInt(); - case i64: - return global.loadAsLong(); - case f32: - return Float.intBitsToFloat(global.loadAsInt()); - case f64: - return Double.longBitsToDouble(global.loadAsLong()); - case anyfunc: - case externref: - return global.loadAsReference(); - - } - throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Incorrect internal Global type"); + public Object globalRead(WasmGlobal global) { + return switch (global.getType()) { + case WasmType.I32_TYPE -> global.loadAsInt(); + case WasmType.I64_TYPE -> global.loadAsLong(); + case WasmType.F32_TYPE -> Float.intBitsToFloat(global.loadAsInt()); + case WasmType.F64_TYPE -> Double.longBitsToDouble(global.loadAsLong()); + case WasmType.V128_TYPE -> throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, WasmJsApiException.V128_VALUE_ACCESS); + default -> { + assert WasmType.isReferenceType(global.getType()); + if (!refTypes) { + throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Invalid value type. Reference types are not enabled."); + } + if (SymbolTable.closedTypeOf(WasmType.EXNREF_TYPE, null).isSupertypeOf(global.getClosedType())) { + throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, WasmJsApiException.EXNREF_VALUE_ACCESS); + } + yield global.loadAsReference(); + } + }; } private Object globalWrite(Object[] args) { checkArgumentCount(args, 2); if (!(args[0] instanceof WasmGlobal global)) { - throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "First argument must be wasm global"); + throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "First argument must be wasm global."); } return globalWrite(global, args[1]); } @@ -880,48 +832,25 @@ public Object globalWrite(WasmGlobal global, Object value) { if (!global.isMutable()) { throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Global is not mutable."); } - ValueType valueType = global.getValueType(); - switch (valueType) { - case i32: - if (!(value instanceof Integer)) { - throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Global type %s, value: %s", valueType, value); - } - global.storeInt((int) value); - break; - case i64: - if (!(value instanceof Long)) { - throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Global type %s, value: %s", valueType, value); - } - global.storeLong((long) value); - break; - case f32: - if (!(value instanceof Float)) { - throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Global type %s, value: %s", valueType, value); - } - global.storeInt(Float.floatToRawIntBits((float) value)); - break; - case f64: - if (!(value instanceof Double)) { - throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Global type %s, value: %s", valueType, value); - } - global.storeLong(Double.doubleToRawLongBits((double) value)); - break; - case anyfunc: + if (!global.getClosedType().matchesValue(value)) { + throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Global type %s, value: %s", ValueType.fromValue(global.getType()), value); + } + switch (global.getType()) { + case WasmType.I32_TYPE -> global.storeInt((int) value); + case WasmType.I64_TYPE -> global.storeLong((long) value); + case WasmType.F32_TYPE -> global.storeInt(Float.floatToRawIntBits((float) value)); + case WasmType.F64_TYPE -> global.storeLong(Double.doubleToRawLongBits((double) value)); + case WasmType.V128_TYPE -> throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, WasmJsApiException.V128_VALUE_ACCESS); + default -> { + assert WasmType.isReferenceType(global.getType()); if (!refTypes) { - throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Invalid value type. Reference types are not enabled"); + throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Invalid value type. Reference types are not enabled."); } - if (!(value == WasmConstant.NULL || value instanceof WasmFunctionInstance)) { - throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Global type %s, value: %s", valueType, value); - } else { - global.storeReference(value); - } - break; - case externref: - if (!refTypes) { - throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Invalid value type. Reference types are not enabled"); + if (SymbolTable.closedTypeOf(WasmType.EXNREF_TYPE, null).isSupertypeOf(global.getClosedType())) { + throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, WasmJsApiException.EXNREF_VALUE_ACCESS); } global.storeReference(value); - break; + } } return WasmConstant.VOID; } @@ -942,7 +871,7 @@ public static Object tagAlloc(Object[] args) { } public static WasmTag tagAlloc(FuncType type) { - return new WasmTag(type.toFunctionType()); + return new WasmTag(type.toClosedFunctionType()); } public static Object tagType(Object[] args) { @@ -958,47 +887,19 @@ public Object exnAlloc(Object[] args) { if (!(args[0] instanceof WasmTag tag)) { throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "First argument must be a wasm tag"); } - final FuncType type = FuncType.fromFunctionType(tag.type()); + final FuncType type = FuncType.fromClosedFunctionType(tag.type()); final int paramCount = type.paramCount(); checkArgumentCount(args, paramCount + 1); final Object[] fields = new Object[paramCount]; for (int i = 0; i < paramCount; i++) { final ValueType paramType = type.paramTypeAt(i); final Object value = args[i + 1]; - switch (paramType) { - case i32: - if (!(value instanceof Integer)) { - throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Param type %s, value: %s", paramType, value); - } - break; - case i64: - if (!(value instanceof Long)) { - throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Param type %s, value: %s", paramType, value); - } - break; - case f32: - if (!(value instanceof Float)) { - throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Param type %s, value: %s", paramType, value); - } - break; - case f64: - if (!(value instanceof Double)) { - throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Param type %s, value: %s", paramType, value); - } - break; - case anyfunc: - if (!refTypes) { - throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Invalid value type. Reference types are not enabled"); - } - if (!(value == WasmConstant.NULL || value instanceof WasmFunctionInstance)) { - throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Param type %s, value: %s", paramType, value); - } - break; - case externref: - if (!refTypes) { - throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Invalid value type. Reference types are not enabled"); - } - break; + + if (!paramType.asClosedValueType().matchesValue(value)) { + throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Param type %s, value: %s", paramType, value); + } + if (ValueType.isReferenceType(paramType) && !refTypes) { + throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Invalid value type. Reference types are not enabled."); } fields[i] = value; } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/collection/ByteArrayList.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/collection/ByteArrayList.java index 0ec417861399..f377ae6d2331 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/collection/ByteArrayList.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/collection/ByteArrayList.java @@ -40,9 +40,6 @@ */ package org.graalvm.wasm.collection; -import java.util.Arrays; -import java.util.NoSuchElementException; - public final class ByteArrayList { private static final byte[] EMPTY_BYTE_ARRAY = new byte[0]; @@ -60,20 +57,38 @@ public void add(byte b) { size++; } - public void push(byte b) { - add(b); - } - - public byte popBack() { - if (size() == 0) { - throw new NoSuchElementException("Cannot pop from an empty ByteArrayList."); + /** + * Add the little-endian LEB128 encoding of the signed {@code int} value to the list. + */ + public void addSignedInt32(int valueArg) { + int value = valueArg; + while (true) { + int b = value & 0x7f; + value >>= 7; + if ((value == 0 && (b & 0x40) == 0) || (value == -1 && (b & 0x40) != 0)) { + add((byte) b); + break; + } else { + add((byte) (b | 0x80)); + } } - size--; - return array[size]; } - public byte top() { - return array[size - 1]; + /** + * Add the little-endian LEB128 encoding of the unsigned {@code int} value to the list. + */ + public void addUnsignedInt32(int valueArg) { + int value = valueArg; + while (true) { + int b = value & 0x7f; + value >>>= 7; + if (value == 0) { + add((byte) b); + break; + } else { + add((byte) (b | 0x80)); + } + } } public void set(int index, byte b) { @@ -134,17 +149,4 @@ public byte[] toArray() { return EMPTY_BYTE_ARRAY; } } - - public static byte[] concat(ByteArrayList... byteArrayLists) { - int totalSize = Arrays.stream(byteArrayLists).mapToInt(ByteArrayList::size).sum(); - byte[] result = new byte[totalSize]; - int resultOffset = 0; - for (ByteArrayList byteArrayList : byteArrayLists) { - if (byteArrayList.array != null) { - System.arraycopy(byteArrayList.array, 0, result, resultOffset, byteArrayList.size); - resultOffset += byteArrayList.size; - } - } - return result; - } } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/collection/IntArrayList.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/collection/IntArrayList.java index 460c45fff701..d54eefdb745a 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/collection/IntArrayList.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/collection/IntArrayList.java @@ -40,6 +40,8 @@ */ package org.graalvm.wasm.collection; +import java.util.NoSuchElementException; + public final class IntArrayList { public static final int[] EMPTY_INT_ARRAY = new int[0]; @@ -57,11 +59,22 @@ public void add(int n) { offset++; } + public void push(int b) { + add(b); + } + public int popBack() { + if (size() == 0) { + throw new NoSuchElementException("Cannot pop from an empty IntArrayList."); + } offset--; return array[offset]; } + public int top() { + return array[offset - 1]; + } + public int get(int index) { return array[index]; } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/Bytecode.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/Bytecode.java index 50d177c05564..ae0d5d040127 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/Bytecode.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/Bytecode.java @@ -333,8 +333,8 @@ public class Bytecode { public static final int REF_IS_NULL = 0xF7; public static final int REF_FUNC = 0xF8; - public static final int TABLE_GET = 0xF9; - public static final int TABLE_SET = 0xFA; + public static final int CALL_REF_U8 = 0xF9; + public static final int CALL_REF_I32 = 0xFA; public static final int MISC = 0xFB; @@ -355,26 +355,34 @@ public class Bytecode { public static final int MEMORY_INIT = 0x08; public static final int MEMORY64_INIT = 0x0A; public static final int DATA_DROP = 0x0C; - public static final int DATA_DROP_UNSAFE = 0x0D; - public static final int MEMORY_COPY = 0x0E; - public static final int MEMORY64_COPY_D32_S64 = 0x0F; - public static final int MEMORY64_COPY_D64_S32 = 0x10; - public static final int MEMORY64_COPY_D64_S64 = 0x11; - public static final int MEMORY_FILL = 0x12; - public static final int MEMORY64_FILL = 0x13; - - public static final int MEMORY64_SIZE = 0x14; - public static final int MEMORY64_GROW = 0x15; - public static final int TABLE_INIT = 0x16; - public static final int ELEM_DROP = 0x17; - public static final int TABLE_COPY = 0x18; - public static final int TABLE_GROW = 0x19; - public static final int TABLE_SIZE = 0x1A; - public static final int TABLE_FILL = 0x1B; + public static final int MEMORY_COPY = 0x0D; + public static final int MEMORY64_COPY_D32_S64 = 0x0E; + public static final int MEMORY64_COPY_D64_S32 = 0x0F; + public static final int MEMORY64_COPY_D64_S64 = 0x10; + public static final int MEMORY_FILL = 0x11; + public static final int MEMORY64_FILL = 0x12; + + public static final int MEMORY64_SIZE = 0x13; + public static final int MEMORY64_GROW = 0x14; + public static final int TABLE_INIT = 0x15; + public static final int ELEM_DROP = 0x16; + public static final int TABLE_COPY = 0x17; + public static final int TABLE_GROW = 0x18; + public static final int TABLE_SIZE = 0x19; + public static final int TABLE_FILL = 0x1A; // Exception opcodes - public static final int THROW = 0x1C; - public static final int THROW_REF = 0x1D; + public static final int THROW = 0x1B; + public static final int THROW_REF = 0x1C; + + // Typed function references opcodes + public static final int TABLE_GET = 0x1D; + public static final int TABLE_SET = 0x1E; + public static final int REF_AS_NON_NULL = 0x1F; + public static final int BR_ON_NULL_U8 = 0x20; + public static final int BR_ON_NULL_I32 = 0x21; + public static final int BR_ON_NON_NULL_U8 = 0x22; + public static final int BR_ON_NON_NULL_I32 = 0x23; // Atomic opcodes public static final int ATOMIC_I32_LOAD = 0x00; diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/BytecodeBitEncoding.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/BytecodeBitEncoding.java index 76fc7ee6ac92..9d1a247c6061 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/BytecodeBitEncoding.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/BytecodeBitEncoding.java @@ -137,14 +137,19 @@ public class BytecodeBitEncoding { public static final int ELEM_SEG_OFFSET_ADDRESS_U16 = 0b0000_0010; public static final int ELEM_SEG_OFFSET_ADDRESS_I32 = 0b0000_0011; - public static final int ELEM_SEG_TYPE_FUNREF = 0b0001_0000; - public static final int ELEM_SEG_TYPE_EXTERNREF = 0b0010_0000; - public static final int ELEM_SEG_TYPE_EXNREF = 0b0011_0000; + public static final int ELEM_SEG_TYPE_MASK = 0b0011_0000; + public static final int ELEM_SEG_TYPE_I8 = 0b0001_0000; + public static final int ELEM_SEG_TYPE_I16 = 0b0010_0000; + public static final int ELEM_SEG_TYPE_I32 = 0b0011_0000; public static final int ELEM_SEG_MODE_VALUE = 0b0000_1111; // Elem items + public static final int ELEM_ITEM_REF_NULL_ENTRY_PREFIX = 0; + public static final int ELEM_ITEM_REF_FUNC_ENTRY_PREFIX = 1; + public static final int ELEM_ITEM_GLOBAL_GET_ENTRY_PREFIX = 2; + public static final int ELEM_ITEM_TYPE_MASK = 0b1000_0000; public static final int ELEM_ITEM_TYPE_FUNCTION_INDEX = 0b0000_0000; public static final int ELEM_ITEM_TYPE_GLOBAL_INDEX = 0b1000_0000; diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/Instructions.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/Instructions.java index e0a7b68bad5b..183245e9cdda 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/Instructions.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/Instructions.java @@ -266,16 +266,21 @@ public final class Instructions { public static final int REF_IS_NULL = 0xD1; public static final int REF_FUNC = 0xD2; - public static final int MEMORY_INIT = 8; - public static final int DATA_DROP = 9; - public static final int MEMORY_COPY = 10; - public static final int MEMORY_FILL = 11; - public static final int TABLE_INIT = 12; - public static final int ELEM_DROP = 13; - public static final int TABLE_COPY = 14; - public static final int TABLE_GROW = 15; - public static final int TABLE_SIZE = 16; - public static final int TABLE_FILL = 17; + public static final int MEMORY_INIT = 0x08; + public static final int DATA_DROP = 0x09; + public static final int MEMORY_COPY = 0x0A; + public static final int MEMORY_FILL = 0x0B; + public static final int TABLE_INIT = 0x0C; + public static final int ELEM_DROP = 0x0D; + public static final int TABLE_COPY = 0x0E; + public static final int TABLE_GROW = 0x0F; + public static final int TABLE_SIZE = 0x10; + public static final int TABLE_FILL = 0x11; + + public static final int CALL_REF = 0x14; + public static final int REF_AS_NON_NULL = 0xD4; + public static final int BR_ON_NULL = 0xD5; + public static final int BR_ON_NON_NULL = 0xD6; public static final int ATOMIC = 0xFE; diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/StackEffects.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/StackEffects.java new file mode 100644 index 000000000000..710f3bf543c2 --- /dev/null +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/StackEffects.java @@ -0,0 +1,392 @@ +/* + * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * The Universal Permissive License (UPL), Version 1.0 + * + * Subject to the condition set forth below, permission is hereby granted to any + * person obtaining a copy of this software, associated documentation and/or + * data (collectively the "Software"), free of charge and under any and all + * copyright rights in the Software, and any and all patent rights owned or + * freely licensable by each licensor hereunder covering either (i) the + * unmodified Software as contributed to or provided by such licensor, or (ii) + * the Larger Works (as defined below), to deal in both + * + * (a) the Software, and + * + * (b) any piece of software and/or hardware listed in the lrgrwrks.txt file if + * one is included with the Software each a "Larger Work" to which the Software + * is contributed by such licensors), + * + * without restriction, including without limitation the rights to copy, create + * derivative works of, display, perform, and distribute the Software and make, + * use, sell, offer for sale, import, export, have made, and have sold the + * Software and the Larger Work(s), and to sublicense the foregoing rights on + * either these or other terms. + * + * This license is subject to the following condition: + * + * The above copyright notice and either this complete permission notice or at a + * minimum a reference to the UPL must be included in all copies or substantial + * portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +package org.graalvm.wasm.constants; + +import com.oracle.truffle.api.CompilerDirectives.CompilationFinal; +import org.graalvm.wasm.nodes.WasmFunctionNode; + +/** + * The data stored in this class tells us how much the size of the operand stack changes after + * executing a single {@code v128} instruction. This is useful in {@link WasmFunctionNode}, since it + * allows us to execute {@code v128} instructions in a separate method without needing to return the + * new value of the stack pointer from that method. + */ +public final class StackEffects { + + private static final byte NO_EFFECT = 0; + private static final byte PUSH_1 = 1; + private static final byte POP_1 = -1; + private static final byte POP_2 = -2; + private static final byte POP_3 = -3; + private static final byte UNREACHABLE = Byte.MIN_VALUE; + + @CompilationFinal(dimensions = 1) private static final byte[] miscOpStackEffects = new byte[256]; + @CompilationFinal(dimensions = 1) private static final byte[] vectorOpStackEffects = new byte[256]; + + static { + miscOpStackEffects[Bytecode.I32_TRUNC_SAT_F32_S] = NO_EFFECT; + miscOpStackEffects[Bytecode.I32_TRUNC_SAT_F32_U] = NO_EFFECT; + miscOpStackEffects[Bytecode.I32_TRUNC_SAT_F64_S] = NO_EFFECT; + miscOpStackEffects[Bytecode.I32_TRUNC_SAT_F64_U] = NO_EFFECT; + miscOpStackEffects[Bytecode.I64_TRUNC_SAT_F32_S] = NO_EFFECT; + miscOpStackEffects[Bytecode.I64_TRUNC_SAT_F32_U] = NO_EFFECT; + miscOpStackEffects[Bytecode.I64_TRUNC_SAT_F64_S] = NO_EFFECT; + miscOpStackEffects[Bytecode.I64_TRUNC_SAT_F64_U] = NO_EFFECT; + miscOpStackEffects[Bytecode.MEMORY_INIT] = POP_3; + miscOpStackEffects[Bytecode.MEMORY64_INIT] = POP_3; + miscOpStackEffects[Bytecode.DATA_DROP] = NO_EFFECT; + miscOpStackEffects[Bytecode.MEMORY_COPY] = POP_3; + miscOpStackEffects[Bytecode.MEMORY64_COPY_D32_S64] = POP_3; + miscOpStackEffects[Bytecode.MEMORY64_COPY_D64_S32] = POP_3; + miscOpStackEffects[Bytecode.MEMORY64_COPY_D64_S64] = POP_3; + miscOpStackEffects[Bytecode.MEMORY_FILL] = POP_3; + miscOpStackEffects[Bytecode.MEMORY64_FILL] = POP_3; + miscOpStackEffects[Bytecode.MEMORY64_SIZE] = PUSH_1; + miscOpStackEffects[Bytecode.MEMORY64_GROW] = NO_EFFECT; + miscOpStackEffects[Bytecode.TABLE_INIT] = POP_3; + miscOpStackEffects[Bytecode.ELEM_DROP] = NO_EFFECT; + miscOpStackEffects[Bytecode.TABLE_COPY] = POP_3; + miscOpStackEffects[Bytecode.TABLE_GROW] = POP_1; + miscOpStackEffects[Bytecode.TABLE_SIZE] = PUSH_1; + miscOpStackEffects[Bytecode.TABLE_FILL] = POP_3; + miscOpStackEffects[Bytecode.THROW] = UNREACHABLE; // unused, because stack effect is + // followed by throw + miscOpStackEffects[Bytecode.THROW_REF] = UNREACHABLE; // unused, because stack effect is + // followed by throw + miscOpStackEffects[Bytecode.TABLE_GET] = NO_EFFECT; + miscOpStackEffects[Bytecode.TABLE_SET] = POP_2; + miscOpStackEffects[Bytecode.REF_AS_NON_NULL] = NO_EFFECT; + miscOpStackEffects[Bytecode.BR_ON_NULL_U8] = UNREACHABLE; // unused, because stack effect is + // dynamic + miscOpStackEffects[Bytecode.BR_ON_NULL_I32] = UNREACHABLE; // unused, because stack effect + // is dynamic + miscOpStackEffects[Bytecode.BR_ON_NON_NULL_U8] = UNREACHABLE; // unused, because stack + // effect is dynamic + miscOpStackEffects[Bytecode.BR_ON_NON_NULL_I32] = UNREACHABLE; // unused, because stack + // effect is dynamic + + vectorOpStackEffects[Bytecode.VECTOR_V128_LOAD] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_V128_LOAD8X8_S] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_V128_LOAD8X8_U] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_V128_LOAD16X4_S] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_V128_LOAD16X4_U] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_V128_LOAD32X2_S] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_V128_LOAD32X2_U] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_V128_LOAD8_SPLAT] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_V128_LOAD16_SPLAT] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_V128_LOAD32_SPLAT] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_V128_LOAD64_SPLAT] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_V128_LOAD32_ZERO] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_V128_LOAD64_ZERO] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_V128_STORE] = POP_2; + vectorOpStackEffects[Bytecode.VECTOR_V128_LOAD8_LANE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_V128_LOAD16_LANE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_V128_LOAD32_LANE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_V128_LOAD64_LANE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_V128_STORE8_LANE] = POP_2; + vectorOpStackEffects[Bytecode.VECTOR_V128_STORE16_LANE] = POP_2; + vectorOpStackEffects[Bytecode.VECTOR_V128_STORE32_LANE] = POP_2; + vectorOpStackEffects[Bytecode.VECTOR_V128_STORE64_LANE] = POP_2; + vectorOpStackEffects[Bytecode.VECTOR_V128_CONST] = PUSH_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_SHUFFLE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_EXTRACT_LANE_S] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_EXTRACT_LANE_U] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_REPLACE_LANE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_EXTRACT_LANE_S] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_EXTRACT_LANE_U] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_REPLACE_LANE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_EXTRACT_LANE] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_REPLACE_LANE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_EXTRACT_LANE] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_REPLACE_LANE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_EXTRACT_LANE] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_REPLACE_LANE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_EXTRACT_LANE] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_REPLACE_LANE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_SWIZZLE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_SPLAT] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_SPLAT] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_SPLAT] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_SPLAT] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_SPLAT] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_SPLAT] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_EQ] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_NE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_LT_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_LT_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_GT_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_GT_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_LE_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_LE_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_GE_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_GE_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_EQ] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_NE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_LT_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_LT_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_GT_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_GT_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_LE_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_LE_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_GE_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_GE_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_EQ] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_NE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_LT_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_LT_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_GT_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_GT_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_LE_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_LE_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_GE_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_GE_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_EQ] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_NE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_LT_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_GT_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_LE_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_GE_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_EQ] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_NE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_LT] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_GT] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_LE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_GE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_EQ] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_NE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_LT] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_GT] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_LE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_GE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_V128_NOT] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_V128_AND] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_V128_ANDNOT] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_V128_OR] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_V128_XOR] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_V128_BITSELECT] = POP_2; + vectorOpStackEffects[Bytecode.VECTOR_V128_ANY_TRUE] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_ABS] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_NEG] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_POPCNT] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_ALL_TRUE] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_BITMASK] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_NARROW_I16X8_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_NARROW_I16X8_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_SHL] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_SHR_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_SHR_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_ADD] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_ADD_SAT_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_ADD_SAT_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_SUB] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_SUB_SAT_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_SUB_SAT_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_MIN_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_MIN_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_MAX_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_MAX_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_AVGR_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_EXTADD_PAIRWISE_I8X16_S] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_EXTADD_PAIRWISE_I8X16_U] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_ABS] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_NEG] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_Q15MULR_SAT_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_ALL_TRUE] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_BITMASK] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_NARROW_I32X4_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_NARROW_I32X4_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_EXTEND_LOW_I8X16_S] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_EXTEND_HIGH_I8X16_S] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_EXTEND_LOW_I8X16_U] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_EXTEND_HIGH_I8X16_U] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_SHL] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_SHR_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_SHR_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_ADD] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_ADD_SAT_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_ADD_SAT_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_SUB] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_SUB_SAT_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_SUB_SAT_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_MUL] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_MIN_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_MIN_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_MAX_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_MAX_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_AVGR_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_EXTMUL_LOW_I8X16_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_EXTMUL_HIGH_I8X16_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_EXTMUL_LOW_I8X16_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_EXTMUL_HIGH_I8X16_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_EXTADD_PAIRWISE_I16X8_S] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_EXTADD_PAIRWISE_I16X8_U] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_ABS] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_NEG] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_ALL_TRUE] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_BITMASK] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_EXTEND_LOW_I16X8_S] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_EXTEND_HIGH_I16X8_S] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_EXTEND_LOW_I16X8_U] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_EXTEND_HIGH_I16X8_U] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_SHL] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_SHR_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_SHR_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_ADD] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_SUB] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_MUL] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_MIN_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_MIN_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_MAX_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_MAX_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_DOT_I16X8_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_EXTMUL_LOW_I16X8_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_EXTMUL_HIGH_I16X8_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_EXTMUL_LOW_I16X8_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_EXTMUL_HIGH_I16X8_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_ABS] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_NEG] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_ALL_TRUE] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_BITMASK] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_EXTEND_LOW_I32X4_S] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_EXTEND_HIGH_I32X4_S] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_EXTEND_LOW_I32X4_U] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_EXTEND_HIGH_I32X4_U] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_SHL] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_SHR_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_SHR_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_ADD] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_SUB] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_MUL] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_EXTMUL_LOW_I32X4_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_EXTMUL_HIGH_I32X4_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_EXTMUL_LOW_I32X4_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_EXTMUL_HIGH_I32X4_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_CEIL] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_FLOOR] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_TRUNC] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_NEAREST] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_ABS] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_NEG] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_SQRT] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_ADD] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_SUB] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_MUL] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_DIV] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_MIN] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_MAX] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_PMIN] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_PMAX] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_CEIL] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_FLOOR] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_TRUNC] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_NEAREST] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_ABS] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_NEG] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_SQRT] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_ADD] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_SUB] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_MUL] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_DIV] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_MIN] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_MAX] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_PMIN] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_PMAX] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_TRUNC_SAT_F32X4_S] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_TRUNC_SAT_F32X4_U] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_CONVERT_I32X4_S] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_CONVERT_I32X4_U] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_TRUNC_SAT_F64X2_S_ZERO] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_TRUNC_SAT_F64X2_U_ZERO] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_CONVERT_LOW_I32X4_S] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_CONVERT_LOW_I32X4_U] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_DEMOTE_F64X2_ZERO] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_PROMOTE_LOW_F32X4] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_RELAXED_SWIZZLE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_RELAXED_TRUNC_F32X4_S] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_RELAXED_TRUNC_F32X4_U] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_RELAXED_TRUNC_F64X2_S_ZERO] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_RELAXED_TRUNC_F64X2_U_ZERO] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_RELAXED_MADD] = POP_2; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_RELAXED_NMADD] = POP_2; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_RELAXED_MADD] = POP_2; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_RELAXED_NMADD] = POP_2; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_RELAXED_LANESELECT] = POP_2; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_RELAXED_LANESELECT] = POP_2; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_RELAXED_LANESELECT] = POP_2; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_RELAXED_LANESELECT] = POP_2; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_RELAXED_MIN] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_RELAXED_MAX] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_RELAXED_MIN] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_RELAXED_MAX] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_RELAXED_Q15MULR_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_RELAXED_DOT_I8X16_I7X16_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_RELAXED_DOT_I8X16_I7X16_ADD_S] = POP_2; + } + + private StackEffects() { + } + + /** + * Indicates by how much the stack grows (positive return value) or shrinks (negative return + * value) after executing the {@link Bytecode#MISC}-prefixed bytecode {@code miscOpcode}. + * + * @param miscOpcode the {@code MISC} bytecode being executed + * @return the difference between the stack size after executing the bytecode and before + * executing the bytecode + */ + public static byte getMiscOpStackEffect(int miscOpcode) { + assert miscOpcode < 256; + return miscOpStackEffects[miscOpcode]; + } + + /** + * Indicates by how much the stack grows (positive return value) or shrinks (negative return + * value) after executing the {@link Bytecode#MISC}-prefixed bytecode {@code vectorOpcode}. + * + * @param vectorOpcode the {@code VECTOR} bytecode being executed + * @return the difference between the stack size after executing the bytecode and before + * executing the bytecode + */ + public static byte getVectorOpStackEffect(int vectorOpcode) { + assert vectorOpcode < 256; + return vectorOpStackEffects[vectorOpcode]; + } +} diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/Vector128OpStackEffects.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/Vector128OpStackEffects.java deleted file mode 100644 index d15a11f796a1..000000000000 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/Vector128OpStackEffects.java +++ /dev/null @@ -1,335 +0,0 @@ -/* - * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved. - * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. - * - * The Universal Permissive License (UPL), Version 1.0 - * - * Subject to the condition set forth below, permission is hereby granted to any - * person obtaining a copy of this software, associated documentation and/or - * data (collectively the "Software"), free of charge and under any and all - * copyright rights in the Software, and any and all patent rights owned or - * freely licensable by each licensor hereunder covering either (i) the - * unmodified Software as contributed to or provided by such licensor, or (ii) - * the Larger Works (as defined below), to deal in both - * - * (a) the Software, and - * - * (b) any piece of software and/or hardware listed in the lrgrwrks.txt file if - * one is included with the Software each a "Larger Work" to which the Software - * is contributed by such licensors), - * - * without restriction, including without limitation the rights to copy, create - * derivative works of, display, perform, and distribute the Software and make, - * use, sell, offer for sale, import, export, have made, and have sold the - * Software and the Larger Work(s), and to sublicense the foregoing rights on - * either these or other terms. - * - * This license is subject to the following condition: - * - * The above copyright notice and either this complete permission notice or at a - * minimum a reference to the UPL must be included in all copies or substantial - * portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -package org.graalvm.wasm.constants; - -import com.oracle.truffle.api.CompilerDirectives.CompilationFinal; -import org.graalvm.wasm.nodes.WasmFunctionNode; - -/** - * The data stored in this class tells us how much the size of the operand stack changes after - * executing a single {@code v128} instruction. This is useful in {@link WasmFunctionNode}, since it - * allows us to execute {@code v128} instructions in a separate method without needing to return the - * new value of the stack pointer from that method. - */ -public final class Vector128OpStackEffects { - - private static final byte NO_EFFECT = 0; - private static final byte PUSH_1 = 1; - private static final byte POP_1 = -1; - private static final byte POP_2 = -2; - - @CompilationFinal(dimensions = 1) private static final byte[] vector128OpStackEffects = new byte[256]; - - static { - vector128OpStackEffects[Bytecode.VECTOR_V128_LOAD] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_V128_LOAD8X8_S] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_V128_LOAD8X8_U] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_V128_LOAD16X4_S] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_V128_LOAD16X4_U] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_V128_LOAD32X2_S] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_V128_LOAD32X2_U] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_V128_LOAD8_SPLAT] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_V128_LOAD16_SPLAT] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_V128_LOAD32_SPLAT] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_V128_LOAD64_SPLAT] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_V128_LOAD32_ZERO] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_V128_LOAD64_ZERO] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_V128_STORE] = POP_2; - vector128OpStackEffects[Bytecode.VECTOR_V128_LOAD8_LANE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_V128_LOAD16_LANE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_V128_LOAD32_LANE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_V128_LOAD64_LANE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_V128_STORE8_LANE] = POP_2; - vector128OpStackEffects[Bytecode.VECTOR_V128_STORE16_LANE] = POP_2; - vector128OpStackEffects[Bytecode.VECTOR_V128_STORE32_LANE] = POP_2; - vector128OpStackEffects[Bytecode.VECTOR_V128_STORE64_LANE] = POP_2; - vector128OpStackEffects[Bytecode.VECTOR_V128_CONST] = PUSH_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_SHUFFLE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_EXTRACT_LANE_S] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_EXTRACT_LANE_U] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_REPLACE_LANE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_EXTRACT_LANE_S] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_EXTRACT_LANE_U] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_REPLACE_LANE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_EXTRACT_LANE] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_REPLACE_LANE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_EXTRACT_LANE] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_REPLACE_LANE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_EXTRACT_LANE] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_REPLACE_LANE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_EXTRACT_LANE] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_REPLACE_LANE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_SWIZZLE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_SPLAT] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_SPLAT] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_SPLAT] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_SPLAT] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_SPLAT] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_SPLAT] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_EQ] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_NE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_LT_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_LT_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_GT_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_GT_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_LE_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_LE_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_GE_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_GE_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_EQ] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_NE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_LT_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_LT_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_GT_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_GT_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_LE_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_LE_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_GE_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_GE_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_EQ] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_NE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_LT_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_LT_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_GT_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_GT_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_LE_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_LE_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_GE_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_GE_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_EQ] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_NE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_LT_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_GT_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_LE_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_GE_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_EQ] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_NE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_LT] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_GT] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_LE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_GE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_EQ] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_NE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_LT] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_GT] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_LE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_GE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_V128_NOT] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_V128_AND] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_V128_ANDNOT] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_V128_OR] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_V128_XOR] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_V128_BITSELECT] = POP_2; - vector128OpStackEffects[Bytecode.VECTOR_V128_ANY_TRUE] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_ABS] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_NEG] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_POPCNT] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_ALL_TRUE] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_BITMASK] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_NARROW_I16X8_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_NARROW_I16X8_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_SHL] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_SHR_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_SHR_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_ADD] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_ADD_SAT_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_ADD_SAT_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_SUB] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_SUB_SAT_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_SUB_SAT_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_MIN_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_MIN_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_MAX_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_MAX_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_AVGR_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_EXTADD_PAIRWISE_I8X16_S] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_EXTADD_PAIRWISE_I8X16_U] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_ABS] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_NEG] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_Q15MULR_SAT_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_ALL_TRUE] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_BITMASK] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_NARROW_I32X4_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_NARROW_I32X4_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_EXTEND_LOW_I8X16_S] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_EXTEND_HIGH_I8X16_S] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_EXTEND_LOW_I8X16_U] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_EXTEND_HIGH_I8X16_U] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_SHL] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_SHR_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_SHR_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_ADD] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_ADD_SAT_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_ADD_SAT_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_SUB] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_SUB_SAT_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_SUB_SAT_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_MUL] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_MIN_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_MIN_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_MAX_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_MAX_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_AVGR_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_EXTMUL_LOW_I8X16_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_EXTMUL_HIGH_I8X16_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_EXTMUL_LOW_I8X16_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_EXTMUL_HIGH_I8X16_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_EXTADD_PAIRWISE_I16X8_S] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_EXTADD_PAIRWISE_I16X8_U] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_ABS] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_NEG] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_ALL_TRUE] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_BITMASK] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_EXTEND_LOW_I16X8_S] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_EXTEND_HIGH_I16X8_S] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_EXTEND_LOW_I16X8_U] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_EXTEND_HIGH_I16X8_U] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_SHL] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_SHR_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_SHR_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_ADD] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_SUB] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_MUL] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_MIN_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_MIN_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_MAX_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_MAX_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_DOT_I16X8_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_EXTMUL_LOW_I16X8_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_EXTMUL_HIGH_I16X8_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_EXTMUL_LOW_I16X8_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_EXTMUL_HIGH_I16X8_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_ABS] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_NEG] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_ALL_TRUE] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_BITMASK] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_EXTEND_LOW_I32X4_S] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_EXTEND_HIGH_I32X4_S] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_EXTEND_LOW_I32X4_U] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_EXTEND_HIGH_I32X4_U] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_SHL] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_SHR_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_SHR_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_ADD] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_SUB] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_MUL] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_EXTMUL_LOW_I32X4_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_EXTMUL_HIGH_I32X4_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_EXTMUL_LOW_I32X4_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_EXTMUL_HIGH_I32X4_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_CEIL] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_FLOOR] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_TRUNC] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_NEAREST] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_ABS] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_NEG] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_SQRT] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_ADD] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_SUB] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_MUL] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_DIV] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_MIN] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_MAX] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_PMIN] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_PMAX] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_CEIL] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_FLOOR] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_TRUNC] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_NEAREST] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_ABS] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_NEG] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_SQRT] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_ADD] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_SUB] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_MUL] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_DIV] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_MIN] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_MAX] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_PMIN] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_PMAX] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_TRUNC_SAT_F32X4_S] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_TRUNC_SAT_F32X4_U] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_CONVERT_I32X4_S] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_CONVERT_I32X4_U] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_TRUNC_SAT_F64X2_S_ZERO] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_TRUNC_SAT_F64X2_U_ZERO] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_CONVERT_LOW_I32X4_S] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_CONVERT_LOW_I32X4_U] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_DEMOTE_F64X2_ZERO] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_PROMOTE_LOW_F32X4] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_RELAXED_SWIZZLE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_RELAXED_TRUNC_F32X4_S] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_RELAXED_TRUNC_F32X4_U] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_RELAXED_TRUNC_F64X2_S_ZERO] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_RELAXED_TRUNC_F64X2_U_ZERO] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_RELAXED_MADD] = POP_2; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_RELAXED_NMADD] = POP_2; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_RELAXED_MADD] = POP_2; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_RELAXED_NMADD] = POP_2; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_RELAXED_LANESELECT] = POP_2; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_RELAXED_LANESELECT] = POP_2; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_RELAXED_LANESELECT] = POP_2; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_RELAXED_LANESELECT] = POP_2; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_RELAXED_MIN] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_RELAXED_MAX] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_RELAXED_MIN] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_RELAXED_MAX] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_RELAXED_Q15MULR_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_RELAXED_DOT_I8X16_I7X16_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_RELAXED_DOT_I8X16_I7X16_ADD_S] = POP_2; - } - - private Vector128OpStackEffects() { - } - - /** - * Indicates by how much the stack grows (positive return value) or shrinks (negative return - * value) after executing the {@code v128} instruction with opcode {@code vectorOpcode}. - * - * @param vectorOpcode the {@code v128} instruction being executed - * @return the difference between the stack size after executing the instruction and before - * executing the instruction - */ - public static byte getVector128OpStackEffect(int vectorOpcode) { - assert vectorOpcode < 256; - return vector128OpStackEffects[vectorOpcode]; - } -} diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/Failure.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/Failure.java index 11d25cc18f8f..14a28d8591cb 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/Failure.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/Failure.java @@ -62,9 +62,11 @@ public enum Failure { DATA_COUNT_SECTION_REQUIRED(Type.MALFORMED, "data count section required"), ILLEGAL_OPCODE(Type.MALFORMED, "illegal opcode"), MALFORMED_REFERENCE_TYPE(Type.MALFORMED, "malformed reference type"), + MALFORMED_HEAP_TYPE(Type.MALFORMED, "malformed heap type"), MALFORMED_IMPORT_KIND(Type.MALFORMED, "malformed import kind"), END_OPCODE_EXPECTED(Type.MALFORMED, "END opcode expected"), UNEXPECTED_CONTENT_AFTER_LAST_SECTION(Type.MALFORMED, "unexpected content after last section"), + MALFORMED_LIMITS_FLAGS(Type.MALFORMED, "malformed limits flags"), MALFORMED_MEMOP_FLAGS(Type.MALFORMED, "malformed memop flags"), MALFORMED_CATCH(Type.MALFORMED, "malformed catch clause"), MALFORMED_TAG_ATTRIBUTE(Type.MALFORMED, "malformed tag attribute"), @@ -81,6 +83,7 @@ public enum Failure { MULTIPLE_TABLES(Type.INVALID, "multiple tables"), LOOP_INPUT(Type.INVALID, "non-empty loop input type"), UNKNOWN_LOCAL(Type.INVALID, "unknown local"), + UNINITIALIZED_LOCAL(Type.INVALID, "uninitialized local"), UNKNOWN_GLOBAL(Type.INVALID, "unknown global"), UNKNOWN_MEMORY(Type.INVALID, "unknown memory"), UNKNOWN_TABLE(Type.INVALID, "unknown table"), @@ -91,7 +94,7 @@ public enum Failure { START_FUNCTION_PARAMS(Type.INVALID, "start function"), LIMIT_MINIMUM_GREATER_THAN_MAXIMUM(Type.INVALID, "size minimum must not be greater than maximum"), DUPLICATE_EXPORT(Type.INVALID, "duplicate export name"), - IMMUTABLE_GLOBAL_WRITE(Type.INVALID, "global is immutable"), + IMMUTABLE_GLOBAL_WRITE(Type.INVALID, "immutable global"), CONSTANT_EXPRESSION_REQUIRED(Type.INVALID, "constant expression required"), LIMIT_EXCEEDED(Type.INVALID, "limit exceeded"), MEMORY_SIZE_LIMIT_EXCEEDED(Type.INVALID, "memory size must be at most 65536 pages (4GiB)"), @@ -140,11 +143,12 @@ public enum Failure { OUT_OF_BOUNDS_MEMORY_ACCESS(Type.TRAP, "out of bounds memory access"), UNALIGNED_ATOMIC(Type.TRAP, "unaligned atomic"), EXPECTED_SHARED_MEMORY(Type.TRAP, "expected shared memory"), - INDIRECT_CALL_TYPE__MISMATCH(Type.TRAP, "indirect call type mismatch"), + INDIRECT_CALL_TYPE_MISMATCH(Type.TRAP, "indirect call type mismatch"), INVALID_MULTI_VALUE_ARITY(Type.TRAP, "provided multi-value size does not match function type"), INVALID_TYPE_IN_MULTI_VALUE(Type.TRAP, "type of value in multi-value does not match the function type"), - NULL_REFERENCE(Type.TRAP, "defined element is ref.null"), + NULL_REFERENCE(Type.TRAP, "null reference"), + NULL_FUNCTION_REFERENCE(Type.TRAP, "null function reference"), OUT_OF_BOUNDS_TABLE_ACCESS(Type.TRAP, "out of bounds table access"), // GraalWasm-specific: TABLE_INSTANCE_SIZE_LIMIT_EXCEEDED(Type.TRAP, "table instance size exceeds limit"), diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/WasmJsApiException.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/WasmJsApiException.java index 20998bbde54b..45dfa81237d8 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/WasmJsApiException.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/WasmJsApiException.java @@ -51,6 +51,9 @@ */ public class WasmJsApiException extends AbstractTruffleException { + public static final String V128_VALUE_ACCESS = "Invalid value type. Accessing v128 values from JS is not allowed."; + public static final String EXNREF_VALUE_ACCESS = "Invalid value type. Accessing exnref values from JS is not allowed."; + public enum Kind { TypeError, RangeError, diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/globals/WasmGlobal.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/globals/WasmGlobal.java index c6afa14d4d95..a3ce0b467feb 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/globals/WasmGlobal.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/globals/WasmGlobal.java @@ -42,13 +42,12 @@ package org.graalvm.wasm.globals; import org.graalvm.wasm.EmbedderDataHolder; -import org.graalvm.wasm.WasmConstant; -import org.graalvm.wasm.WasmFunctionInstance; +import org.graalvm.wasm.SymbolTable; import org.graalvm.wasm.WasmNamesObject; +import org.graalvm.wasm.WasmType; import org.graalvm.wasm.api.ValueType; import org.graalvm.wasm.api.Vector128; import org.graalvm.wasm.constants.GlobalModifier; -import org.graalvm.wasm.exception.WasmRuntimeException; import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; import com.oracle.truffle.api.interop.InteropLibrary; @@ -63,44 +62,56 @@ @ExportLibrary(InteropLibrary.class) public final class WasmGlobal extends EmbedderDataHolder implements TruffleObject { - private final ValueType valueType; + private final int type; private final boolean mutable; + private final SymbolTable symbolTable; private long globalValue; private Object globalObjectValue; - private WasmGlobal(ValueType valueType, boolean mutable) { - this.valueType = valueType; + private WasmGlobal(int type, boolean mutable, SymbolTable symbolTable) { + this.type = type; this.mutable = mutable; + this.symbolTable = symbolTable; } - public WasmGlobal(ValueType valueType, boolean mutable, int value) { - this(valueType, mutable, (long) value); + public static WasmGlobal alloc32(ValueType valueType, boolean mutable, int value) { + return alloc64(valueType, mutable, value); } - public WasmGlobal(ValueType valueType, boolean mutable, long value) { - this(valueType, mutable); - assert ValueType.isNumberType(getValueType()); - this.globalValue = value; + public static WasmGlobal alloc64(ValueType valueType, boolean mutable, long value) { + assert ValueType.isNumberType(valueType); + WasmGlobal result = new WasmGlobal(valueType.value(), mutable, null); + result.globalValue = value; + return result; + } + + public static WasmGlobal allocRef(ValueType valueType, boolean mutable, Object value) { + assert ValueType.isReferenceType(valueType); + WasmGlobal result = new WasmGlobal(valueType.value(), mutable, null); + result.globalObjectValue = value; + return result; } - public WasmGlobal(ValueType valueType, boolean mutable, Object value) { - this(valueType, mutable); - this.globalValue = switch (valueType) { - case i32 -> (int) value; - case i64 -> (long) value; - case f32 -> Float.floatToRawIntBits((float) value); - case f64 -> Double.doubleToRawLongBits((double) value); + public WasmGlobal(int globalIndex, SymbolTable symbolTable, Object value) { + this(symbolTable.globalValueType(globalIndex), symbolTable.isGlobalMutable(globalIndex), symbolTable); + assert symbolTable.globalExternal(globalIndex); + this.globalValue = switch (type) { + case WasmType.I32_TYPE -> (int) value; + case WasmType.I64_TYPE -> (long) value; + case WasmType.F32_TYPE -> Float.floatToRawIntBits((float) value); + case WasmType.F64_TYPE -> Double.doubleToRawLongBits((double) value); default -> 0; }; - this.globalObjectValue = switch (valueType) { - case v128, anyfunc, externref -> value; - default -> null; - }; + this.globalObjectValue = WasmType.isVectorType(type) || WasmType.isReferenceType(type) ? value : null; + } + + public int getType() { + return type; } - public ValueType getValueType() { - return valueType; + public SymbolTable.ClosedValueType getClosedType() { + return SymbolTable.closedTypeOf(getType(), symbolTable); } public boolean isMutable() { @@ -112,44 +123,44 @@ public byte getMutability() { } public int loadAsInt() { - assert ValueType.isNumberType(getValueType()); + assert WasmType.isNumberType(getType()); return (int) globalValue; } public long loadAsLong() { - assert ValueType.isNumberType(getValueType()); + assert WasmType.isNumberType(getType()); return globalValue; } public Vector128 loadAsVector128() { - assert ValueType.isVectorType(getValueType()); + assert WasmType.isVectorType(getType()); assert globalObjectValue != null; return (Vector128) globalObjectValue; } public Object loadAsReference() { - assert ValueType.isReferenceType(getValueType()); + assert WasmType.isReferenceType(getType()); assert globalObjectValue != null; return globalObjectValue; } public void storeInt(int value) { - assert ValueType.isNumberType(getValueType()); + assert WasmType.isNumberType(getType()); this.globalValue = value; } public void storeLong(long value) { - assert ValueType.isNumberType(getValueType()); + assert WasmType.isNumberType(getType()); this.globalValue = value; } public void storeVector128(Vector128 value) { - assert ValueType.isVectorType(getValueType()); + assert WasmType.isVectorType(getType()); this.globalObjectValue = value; } public void storeReference(Object value) { - assert ValueType.isReferenceType(getValueType()); + assert WasmType.isReferenceType(getType()); this.globalObjectValue = value; } @@ -173,13 +184,16 @@ Object readMember(String member) throws UnknownIdentifierException { throw UnknownIdentifierException.create(member); } assert VALUE_MEMBER.equals(member) : member; - return switch (getValueType()) { - case i32 -> loadAsInt(); - case i64 -> loadAsLong(); - case f32 -> Float.intBitsToFloat(loadAsInt()); - case f64 -> Double.longBitsToDouble(loadAsLong()); - case v128 -> loadAsVector128(); - case anyfunc, externref, exnref -> loadAsReference(); + return switch (getType()) { + case WasmType.I32_TYPE -> loadAsInt(); + case WasmType.I64_TYPE -> loadAsLong(); + case WasmType.F32_TYPE -> Float.intBitsToFloat(loadAsInt()); + case WasmType.F64_TYPE -> Double.longBitsToDouble(loadAsLong()); + case WasmType.V128_TYPE -> loadAsVector128(); + default -> { + assert WasmType.isReferenceType(getType()); + yield loadAsReference(); + } }; } @@ -206,34 +220,23 @@ void writeMember(String member, Object value, // Constant variables cannot be modified after linking. throw UnsupportedMessageException.create(); } - switch (getValueType()) { - case i32 -> storeInt(valueLibrary.asInt(value)); - case i64 -> storeLong(valueLibrary.asLong(value)); - case f32 -> storeInt(Float.floatToRawIntBits(valueLibrary.asFloat(value))); - case f64 -> storeLong(Double.doubleToRawLongBits(valueLibrary.asDouble(value))); - case v128 -> { - if (value instanceof Vector128 vector) { - storeVector128(vector); - } - throw UnsupportedMessageException.create(); - } - case anyfunc -> { - if (value == WasmConstant.NULL || value instanceof WasmFunctionInstance) { - storeReference(value); - } - throw UnsupportedMessageException.create(); - } - case externref -> { - if (value instanceof TruffleObject) { - storeReference(value); + switch (type) { + case WasmType.I32_TYPE -> storeInt(valueLibrary.asInt(value)); + case WasmType.I64_TYPE -> storeLong(valueLibrary.asLong(value)); + case WasmType.F32_TYPE -> storeInt(Float.floatToRawIntBits(valueLibrary.asFloat(value))); + case WasmType.F64_TYPE -> storeLong(Double.doubleToRawLongBits(valueLibrary.asDouble(value))); + case WasmType.V128_TYPE -> { + if (!getClosedType().matchesValue(value)) { + throw UnsupportedMessageException.create(); } - throw UnsupportedMessageException.create(); + storeVector128((Vector128) value); } - case exnref -> { - if (value == WasmConstant.NULL || value instanceof WasmRuntimeException) { - storeReference(value); + default -> { + assert WasmType.isReferenceType(type); + if (!getClosedType().matchesValue(value)) { + throw UnsupportedMessageException.create(); } - throw UnsupportedMessageException.create(); + storeReference(value); } } } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java index be1d4f83cd02..746babb31354 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java @@ -84,7 +84,7 @@ import org.graalvm.wasm.constants.Bytecode; import org.graalvm.wasm.constants.BytecodeBitEncoding; import org.graalvm.wasm.constants.ExceptionHandlerType; -import org.graalvm.wasm.constants.Vector128OpStackEffects; +import org.graalvm.wasm.constants.StackEffects; import org.graalvm.wasm.exception.Failure; import org.graalvm.wasm.exception.WasmException; import org.graalvm.wasm.exception.WasmRuntimeException; @@ -326,7 +326,7 @@ public Object executeBodyFromOffset(WasmInstance instance, VirtualFrame frame, i break; } case Bytecode.LABEL_U16: { - final int value = rawPeekU16(bytecode, offset); + final int value = rawPeekU8(bytecode, offset); final int stackSize = rawPeekU8(bytecode, offset + 1); offset += 2; final int resultCount = (value & BytecodeBitEncoding.LABEL_U16_RESULT_VALUE); @@ -535,9 +535,9 @@ public Object executeBodyFromOffset(WasmInstance instance, VirtualFrame frame, i break; } case Bytecode.CALL_INDIRECT_U8: - case Bytecode.CALL_INDIRECT_I32: { - // Extract the function object. - stackPointer--; + case Bytecode.CALL_INDIRECT_I32: + case Bytecode.CALL_REF_U8: + case Bytecode.CALL_REF_I32: { final SymbolTable symtab = module.symbolTable(); final int callNodeIndex; @@ -548,50 +548,68 @@ public Object executeBodyFromOffset(WasmInstance instance, VirtualFrame frame, i expectedFunctionTypeIndex = rawPeekU8(bytecode, offset + 1); tableIndex = rawPeekU8(bytecode, offset + 2); offset += 3; - } else { + } else if (opcode == Bytecode.CALL_INDIRECT_I32) { callNodeIndex = rawPeekI32(bytecode, offset); expectedFunctionTypeIndex = rawPeekI32(bytecode, offset + 4); tableIndex = rawPeekI32(bytecode, offset + 8); offset += 12; + } else if (opcode == Bytecode.CALL_REF_U8) { + callNodeIndex = rawPeekU8(bytecode, offset); + expectedFunctionTypeIndex = rawPeekU8(bytecode, offset + 1); + tableIndex = -1; + offset += 2; + } else { + assert opcode == Bytecode.CALL_REF_I32; + callNodeIndex = rawPeekI32(bytecode, offset); + expectedFunctionTypeIndex = rawPeekI32(bytecode, offset + 4); + tableIndex = -1; + offset += 8; } - final WasmTable table = instance.store().tables().table(instance.tableAddress(tableIndex)); - final Object[] elements = table.elements(); - final int elementIndex = popInt(frame, stackPointer); - if (elementIndex < 0 || elementIndex >= elements.length) { - enterErrorBranch(); - throw WasmException.format(Failure.UNDEFINED_ELEMENT, this, "Element index '%d' out of table bounds.", elementIndex); - } - // Currently, table elements may only be functions. - // We can add a check here when this changes in the future. - final Object element = elements[elementIndex]; - if (element == WasmConstant.NULL) { - enterErrorBranch(); - throw WasmException.format(Failure.UNINITIALIZED_ELEMENT, this, "Table element at index %d is uninitialized.", elementIndex); - } - final WasmFunctionInstance functionInstance; - final WasmFunction function; - final CallTarget target; - final WasmContext functionInstanceContext; - if (element instanceof WasmFunctionInstance) { - functionInstance = (WasmFunctionInstance) element; - function = functionInstance.function(); - target = functionInstance.target(); - functionInstanceContext = functionInstance.context(); + + // Extract the function object. + final Object functionCandidate; + final int elementIndex; + if (opcode == Bytecode.CALL_INDIRECT_U8 || opcode == Bytecode.CALL_INDIRECT_I32) { + final WasmTable table = instance.store().tables().table(instance.tableAddress(tableIndex)); + final Object[] elements = table.elements(); + elementIndex = popInt(frame, --stackPointer); + if (elementIndex < 0 || elementIndex >= elements.length) { + enterErrorBranch(); + throw WasmException.format(Failure.UNDEFINED_ELEMENT, this, "Element index '%d' out of table bounds.", elementIndex); + } + // Currently, table elements may only be functions. + // We can add a check here when this changes in the future. + functionCandidate = elements[elementIndex]; } else { - enterErrorBranch(); - throw WasmException.format(Failure.UNSPECIFIED_TRAP, this, "Unknown table element type: %s", element); + assert opcode == Bytecode.CALL_REF_U8 || opcode == Bytecode.CALL_REF_I32; + functionCandidate = popReference(frame, --stackPointer); + elementIndex = -1; } - int expectedTypeEquivalenceClass = symtab.equivalenceClass(expectedFunctionTypeIndex); + if (!(functionCandidate instanceof WasmFunctionInstance functionInstance)) { + throw callIndirectNotAFunctionError(opcode, functionCandidate, elementIndex); + } + + final WasmFunction function = functionInstance.function(); + final CallTarget target = functionInstance.target(); + final WasmContext functionInstanceContext = functionInstance.context(); // Target function instance must be from the same context. assert functionInstanceContext == WasmContext.get(this); - // Validate that the target function type matches the expected type of the - // indirect call by performing an equivalence-class check. - if (expectedTypeEquivalenceClass != function.typeEquivalenceClass()) { - enterErrorBranch(); - failFunctionTypeCheck(function, expectedFunctionTypeIndex); + if (opcode == Bytecode.CALL_INDIRECT_U8 || opcode == Bytecode.CALL_INDIRECT_I32) { + // Validate that the target function type matches the expected type of + // the indirect call. We first try if the types are equivalent using the + // equivalence classes... + if (symtab.equivalenceClass(expectedFunctionTypeIndex) != function.typeEquivalenceClass()) { + codeEntry.subtypingBranch(); + // If they are not equivalent, we run the full subtype matching + // procedure. + if (!symtab.closedTypeAt(expectedFunctionTypeIndex).isSupertypeOf(function.closedType())) { + enterErrorBranch(); + failFunctionTypeCheck(function, expectedFunctionTypeIndex); + } + } } // Invoke the resolved function. @@ -1497,7 +1515,7 @@ public Object executeBodyFromOffset(WasmInstance instance, VirtualFrame frame, i final Object refType = popReference(frame, stackPointer - 1); pushInt(frame, stackPointer - 1, refType == WasmConstant.NULL ? 1 : 0); break; - case Bytecode.REF_FUNC: + case Bytecode.REF_FUNC: { final int functionIndex = rawPeekI32(bytecode, offset); final WasmFunction function = module.symbolTable().function(functionIndex); final WasmFunctionInstance functionInstance = instance.functionInstance(function); @@ -1505,160 +1523,12 @@ public Object executeBodyFromOffset(WasmInstance instance, VirtualFrame frame, i stackPointer++; offset += 4; break; - case Bytecode.TABLE_GET: { - final int tableIndex = rawPeekI32(bytecode, offset); - table_get(instance, frame, stackPointer, tableIndex); - offset += 4; - break; - } - case Bytecode.TABLE_SET: { - final int tableIndex = rawPeekI32(bytecode, offset); - table_set(instance, frame, stackPointer, tableIndex); - stackPointer -= 2; - offset += 4; - break; } case Bytecode.MISC: { final int miscOpcode = rawPeekU8(bytecode, offset); offset++; CompilerAsserts.partialEvaluationConstant(miscOpcode); switch (miscOpcode) { - case Bytecode.I32_TRUNC_SAT_F32_S: - i32_trunc_sat_f32_s(frame, stackPointer); - break; - case Bytecode.I32_TRUNC_SAT_F32_U: - i32_trunc_sat_f32_u(frame, stackPointer); - break; - case Bytecode.I32_TRUNC_SAT_F64_S: - i32_trunc_sat_f64_s(frame, stackPointer); - break; - case Bytecode.I32_TRUNC_SAT_F64_U: - i32_trunc_sat_f64_u(frame, stackPointer); - break; - case Bytecode.I64_TRUNC_SAT_F32_S: - i64_trunc_sat_f32_s(frame, stackPointer); - break; - case Bytecode.I64_TRUNC_SAT_F32_U: - i64_trunc_sat_f32_u(frame, stackPointer); - break; - case Bytecode.I64_TRUNC_SAT_F64_S: - i64_trunc_sat_f64_s(frame, stackPointer); - break; - case Bytecode.I64_TRUNC_SAT_F64_U: - i64_trunc_sat_f64_u(frame, stackPointer); - break; - case Bytecode.MEMORY_INIT: - case Bytecode.MEMORY64_INIT: { - final int dataIndex = rawPeekI32(bytecode, offset); - final int memoryIndex = rawPeekI32(bytecode, offset + 4); - executeMemoryInit(instance, frame, stackPointer, miscOpcode, memoryIndex, dataIndex); - stackPointer -= 3; - offset += 8; - break; - } - case Bytecode.DATA_DROP: { - final int dataIndex = rawPeekI32(bytecode, offset); - data_drop(instance, dataIndex); - offset += 4; - break; - } - case Bytecode.MEMORY_COPY: - case Bytecode.MEMORY64_COPY_D64_S64: - case Bytecode.MEMORY64_COPY_D64_S32: - case Bytecode.MEMORY64_COPY_D32_S64: { - final int destMemoryIndex = rawPeekI32(bytecode, offset); - final int srcMemoryIndex = rawPeekI32(bytecode, offset + 4); - executeMemoryCopy(instance, frame, stackPointer, miscOpcode, destMemoryIndex, srcMemoryIndex); - stackPointer -= 3; - offset += 8; - break; - } - case Bytecode.MEMORY_FILL: - case Bytecode.MEMORY64_FILL: { - final int memoryIndex = rawPeekI32(bytecode, offset); - executeMemoryFill(instance, frame, stackPointer, miscOpcode, memoryIndex); - stackPointer -= 3; - offset += 4; - break; - } - case Bytecode.TABLE_INIT: { - final int elementIndex = rawPeekI32(bytecode, offset); - final int tableIndex = rawPeekI32(bytecode, offset + 4); - - final int n = popInt(frame, stackPointer - 1); - final int src = popInt(frame, stackPointer - 2); - final int dst = popInt(frame, stackPointer - 3); - table_init(instance, n, src, dst, tableIndex, elementIndex); - stackPointer -= 3; - offset += 8; - break; - } - case Bytecode.ELEM_DROP: { - final int elementIndex = rawPeekI32(bytecode, offset); - instance.dropElemInstance(elementIndex); - offset += 4; - break; - } - case Bytecode.TABLE_COPY: { - final int srcIndex = rawPeekI32(bytecode, offset); - final int dstIndex = rawPeekI32(bytecode, offset + 4); - - final int n = popInt(frame, stackPointer - 1); - final int src = popInt(frame, stackPointer - 2); - final int dst = popInt(frame, stackPointer - 3); - table_copy(instance, n, src, dst, srcIndex, dstIndex); - stackPointer -= 3; - offset += 8; - break; - } - case Bytecode.TABLE_GROW: { - final int tableIndex = rawPeekI32(bytecode, offset); - - final int n = popInt(frame, stackPointer - 1); - final Object val = popReference(frame, stackPointer - 2); - - final int res = table_grow(instance, n, val, tableIndex); - pushInt(frame, stackPointer - 2, res); - stackPointer--; - offset += 4; - break; - } - case Bytecode.TABLE_SIZE: { - final int tableIndex = rawPeekI32(bytecode, offset); - table_size(instance, frame, stackPointer, tableIndex); - stackPointer++; - offset += 4; - break; - } - case Bytecode.TABLE_FILL: { - final int tableIndex = rawPeekI32(bytecode, offset); - - final int n = popInt(frame, stackPointer - 1); - final Object val = popReference(frame, stackPointer - 2); - final int i = popInt(frame, stackPointer - 3); - table_fill(instance, n, val, i, tableIndex); - stackPointer -= 3; - offset += 4; - break; - } - case Bytecode.MEMORY64_SIZE: { - final int memoryIndex = rawPeekI32(bytecode, offset); - offset += 4; - final WasmMemory memory = memory(instance, memoryIndex); - long pageSize = memoryLib(memoryIndex).size(memory); - pushLong(frame, stackPointer, pageSize); - stackPointer++; - break; - } - case Bytecode.MEMORY64_GROW: { - final int memoryIndex = rawPeekI32(bytecode, offset); - offset += 4; - final WasmMemory memory = memory(instance, memoryIndex); - long extraSize = popLong(frame, stackPointer - 1); - long previousSize = memoryLib(memoryIndex).grow(memory, extraSize); - pushLong(frame, stackPointer - 1, previousSize); - break; - } case Bytecode.THROW: { codeEntry.exceptionBranch(); final int tagIndex = rawPeekI32(bytecode, offset); @@ -1680,8 +1550,61 @@ public Object executeBodyFromOffset(WasmInstance instance, VirtualFrame frame, i assert exception instanceof WasmRuntimeException : "Only wasm exceptions can be thrown by throw_ref"; throw (WasmRuntimeException) exception; } - default: - throw CompilerDirectives.shouldNotReachHere(); + case Bytecode.BR_ON_NULL_U8: { + Object reference = popReference(frame, --stackPointer); + if (profileCondition(bytecode, offset + 1, reference == WasmConstant.NULL)) { + final int offsetDelta = rawPeekU8(bytecode, offset); + // BR_ON_NULL_U8 encodes the back jump value as a positive byte + // value. BR_ON_NULL_U8 + // can never perform a forward jump. + offset -= offsetDelta; + } else { + offset += 3; + pushReference(frame, stackPointer++, reference); + } + break; + } + case Bytecode.BR_ON_NULL_I32: { + Object reference = popReference(frame, --stackPointer); + if (profileCondition(bytecode, offset + 4, reference == WasmConstant.NULL)) { + final int offsetDelta = rawPeekI32(bytecode, offset); + offset += offsetDelta; + } else { + offset += 6; + pushReference(frame, stackPointer++, reference); + } + break; + } + case Bytecode.BR_ON_NON_NULL_U8: { + Object reference = popReference(frame, --stackPointer); + if (profileCondition(bytecode, offset + 1, reference != WasmConstant.NULL)) { + final int offsetDelta = rawPeekU8(bytecode, offset); + // BR_ON_NULL_U8 encodes the back jump value as a positive byte + // value. BR_ON_NULL_U8 + // can never perform a forward jump. + offset -= offsetDelta; + pushReference(frame, stackPointer++, reference); + } else { + offset += 3; + } + break; + } + case Bytecode.BR_ON_NON_NULL_I32: { + Object reference = popReference(frame, --stackPointer); + if (profileCondition(bytecode, offset + 4, reference != WasmConstant.NULL)) { + final int offsetDelta = rawPeekI32(bytecode, offset); + offset += offsetDelta; + pushReference(frame, stackPointer++, reference); + } else { + offset += 6; + } + break; + } + default: { + offset = executeMisc(instance, frame, offset, stackPointer, miscOpcode); + stackPointer += StackEffects.getMiscOpStackEffect(miscOpcode); + break; + } } break; } @@ -1717,7 +1640,7 @@ public Object executeBodyFromOffset(WasmInstance instance, VirtualFrame frame, i offset++; CompilerAsserts.partialEvaluationConstant(vectorOpcode); offset = executeVector(instance, frame, offset, stackPointer, vectorOpcode); - stackPointer += Vector128OpStackEffects.getVector128OpStackEffect(vectorOpcode); + stackPointer += StackEffects.getVectorOpStackEffect(vectorOpcode); break; } case Bytecode.NOTIFY: { @@ -1832,11 +1755,31 @@ private int pushExceptionFieldsAndReference(VirtualFrame frame, WasmRuntimeExcep @TruffleBoundary private void failFunctionTypeCheck(WasmFunction function, int expectedFunctionTypeIndex) { - throw WasmException.format(Failure.INDIRECT_CALL_TYPE__MISMATCH, this, + throw WasmException.format(Failure.INDIRECT_CALL_TYPE_MISMATCH, this, "Actual (type %d of function %s) and expected (type %d in module %s) types differ in the indirect call.", function.typeIndex(), function.name(), expectedFunctionTypeIndex, module.name()); } + @HostCompilerDirectives.InliningCutoff + private WasmException callIndirectNotAFunctionError(int opcode, Object functionCandidate, int elementIndex) { + enterErrorBranch(); + if (functionCandidate == WasmConstant.NULL) { + if (opcode == Bytecode.CALL_INDIRECT_U8 || opcode == Bytecode.CALL_INDIRECT_I32) { + throw WasmException.format(Failure.UNINITIALIZED_ELEMENT, this, "Table element at index %d is uninitialized.", elementIndex); + } else { + assert opcode == Bytecode.CALL_REF_U8 || opcode == Bytecode.CALL_REF_I32; + throw WasmException.format(Failure.NULL_FUNCTION_REFERENCE, this, "Function reference is null"); + } + } else { + if (opcode == Bytecode.CALL_INDIRECT_U8 || opcode == Bytecode.CALL_INDIRECT_I32) { + throw WasmException.format(Failure.UNSPECIFIED_TRAP, this, "Unknown table element type: %s", functionCandidate); + } else { + assert opcode == Bytecode.CALL_REF_U8 || opcode == Bytecode.CALL_REF_I32; + throw WasmException.format(Failure.UNSPECIFIED_TRAP, this, "Unknown function object: %s", functionCandidate); + } + } + } + private void check(int v, int limit) { // This is a temporary hack to hoist values out of the loop. if (v >= limit) { @@ -2163,6 +2106,178 @@ private void executeMemoryFill(WasmInstance instance, VirtualFrame frame, int st memory_fill(instance, n, val, dst, memoryIndex); } + private int executeMisc(WasmInstance instance, VirtualFrame frame, int startingOffset, int startingStackPointer, int miscOpcode) { + int offset = startingOffset; + int stackPointer = startingStackPointer; + + CompilerAsserts.partialEvaluationConstant(miscOpcode); + switch (miscOpcode) { + case Bytecode.I32_TRUNC_SAT_F32_S: + i32_trunc_sat_f32_s(frame, stackPointer); + break; + case Bytecode.I32_TRUNC_SAT_F32_U: + i32_trunc_sat_f32_u(frame, stackPointer); + break; + case Bytecode.I32_TRUNC_SAT_F64_S: + i32_trunc_sat_f64_s(frame, stackPointer); + break; + case Bytecode.I32_TRUNC_SAT_F64_U: + i32_trunc_sat_f64_u(frame, stackPointer); + break; + case Bytecode.I64_TRUNC_SAT_F32_S: + i64_trunc_sat_f32_s(frame, stackPointer); + break; + case Bytecode.I64_TRUNC_SAT_F32_U: + i64_trunc_sat_f32_u(frame, stackPointer); + break; + case Bytecode.I64_TRUNC_SAT_F64_S: + i64_trunc_sat_f64_s(frame, stackPointer); + break; + case Bytecode.I64_TRUNC_SAT_F64_U: + i64_trunc_sat_f64_u(frame, stackPointer); + break; + case Bytecode.MEMORY_INIT: + case Bytecode.MEMORY64_INIT: { + final int dataIndex = rawPeekI32(bytecode, offset); + final int memoryIndex = rawPeekI32(bytecode, offset + 4); + executeMemoryInit(instance, frame, stackPointer, miscOpcode, memoryIndex, dataIndex); + stackPointer -= 3; + offset += 8; + break; + } + case Bytecode.DATA_DROP: { + final int dataIndex = rawPeekI32(bytecode, offset); + data_drop(instance, dataIndex); + offset += 4; + break; + } + case Bytecode.MEMORY_COPY: + case Bytecode.MEMORY64_COPY_D64_S64: + case Bytecode.MEMORY64_COPY_D64_S32: + case Bytecode.MEMORY64_COPY_D32_S64: { + final int destMemoryIndex = rawPeekI32(bytecode, offset); + final int srcMemoryIndex = rawPeekI32(bytecode, offset + 4); + executeMemoryCopy(instance, frame, stackPointer, miscOpcode, destMemoryIndex, srcMemoryIndex); + stackPointer -= 3; + offset += 8; + break; + } + case Bytecode.MEMORY_FILL: + case Bytecode.MEMORY64_FILL: { + final int memoryIndex = rawPeekI32(bytecode, offset); + executeMemoryFill(instance, frame, stackPointer, miscOpcode, memoryIndex); + stackPointer -= 3; + offset += 4; + break; + } + case Bytecode.TABLE_INIT: { + final int elementIndex = rawPeekI32(bytecode, offset); + final int tableIndex = rawPeekI32(bytecode, offset + 4); + + final int n = popInt(frame, stackPointer - 1); + final int src = popInt(frame, stackPointer - 2); + final int dst = popInt(frame, stackPointer - 3); + table_init(instance, n, src, dst, tableIndex, elementIndex); + stackPointer -= 3; + offset += 8; + break; + } + case Bytecode.ELEM_DROP: { + final int elementIndex = rawPeekI32(bytecode, offset); + instance.dropElemInstance(elementIndex); + offset += 4; + break; + } + case Bytecode.TABLE_COPY: { + final int srcIndex = rawPeekI32(bytecode, offset); + final int dstIndex = rawPeekI32(bytecode, offset + 4); + + final int n = popInt(frame, stackPointer - 1); + final int src = popInt(frame, stackPointer - 2); + final int dst = popInt(frame, stackPointer - 3); + table_copy(instance, n, src, dst, srcIndex, dstIndex); + stackPointer -= 3; + offset += 8; + break; + } + case Bytecode.TABLE_GROW: { + final int tableIndex = rawPeekI32(bytecode, offset); + + final int n = popInt(frame, stackPointer - 1); + final Object val = popReference(frame, stackPointer - 2); + + final int res = table_grow(instance, n, val, tableIndex); + pushInt(frame, stackPointer - 2, res); + stackPointer--; + offset += 4; + break; + } + case Bytecode.TABLE_SIZE: { + final int tableIndex = rawPeekI32(bytecode, offset); + table_size(instance, frame, stackPointer, tableIndex); + stackPointer++; + offset += 4; + break; + } + case Bytecode.TABLE_FILL: { + final int tableIndex = rawPeekI32(bytecode, offset); + + final int n = popInt(frame, stackPointer - 1); + final Object val = popReference(frame, stackPointer - 2); + final int i = popInt(frame, stackPointer - 3); + table_fill(instance, n, val, i, tableIndex); + stackPointer -= 3; + offset += 4; + break; + } + case Bytecode.MEMORY64_SIZE: { + final int memoryIndex = rawPeekI32(bytecode, offset); + offset += 4; + final WasmMemory memory = memory(instance, memoryIndex); + long pageSize = memoryLib(memoryIndex).size(memory); + pushLong(frame, stackPointer, pageSize); + stackPointer++; + break; + } + case Bytecode.MEMORY64_GROW: { + final int memoryIndex = rawPeekI32(bytecode, offset); + offset += 4; + final WasmMemory memory = memory(instance, memoryIndex); + long extraSize = popLong(frame, stackPointer - 1); + long previousSize = memoryLib(memoryIndex).grow(memory, extraSize); + pushLong(frame, stackPointer - 1, previousSize); + break; + } + case Bytecode.TABLE_GET: { + final int tableIndex = rawPeekI32(bytecode, offset); + table_get(instance, frame, stackPointer, tableIndex); + offset += 4; + break; + } + case Bytecode.TABLE_SET: { + final int tableIndex = rawPeekI32(bytecode, offset); + table_set(instance, frame, stackPointer, tableIndex); + stackPointer -= 2; + offset += 4; + break; + } + case Bytecode.REF_AS_NON_NULL: { + Object reference = popReference(frame, stackPointer - 1); + if (reference == WasmConstant.NULL) { + enterErrorBranch(); + throw WasmException.format(Failure.NULL_REFERENCE, this, "Function reference is null"); + } + pushReference(frame, stackPointer - 1, reference); + break; + } + default: + throw CompilerDirectives.shouldNotReachHere(); + } + + assert stackPointer - startingStackPointer == StackEffects.getMiscOpStackEffect(miscOpcode); + return offset; + } + private int executeAtomic(VirtualFrame frame, int stackPointer, int opcode, WasmMemory memory, WasmMemoryLibrary memoryLib, long memOffset, int indexType64) { switch (opcode) { case Bytecode.ATOMIC_NOTIFY: @@ -3221,7 +3336,7 @@ private int executeVector(WasmInstance instance, VirtualFrame frame, int startin throw CompilerDirectives.shouldNotReachHere(); } - assert stackPointer - startingStackPointer == Vector128OpStackEffects.getVector128OpStackEffect(vectorOpcode); + assert stackPointer - startingStackPointer == StackEffects.getVectorOpStackEffect(vectorOpcode); return offset; } @@ -3359,7 +3474,7 @@ private void storeVectorLane(WasmMemory memory, WasmMemoryLibrary memoryLib, int // Checkstyle: stop method name check private void global_set(WasmInstance instance, VirtualFrame frame, int stackPointer, int index) { - final byte type = module.globalValueType(index); + final int type = module.globalValueType(index); CompilerAsserts.partialEvaluationConstant(type); // For global.set, we don't need to make sure that the referenced global is // mutable. @@ -3369,60 +3484,34 @@ private void global_set(WasmInstance instance, VirtualFrame frame, int stackPoin final GlobalRegistry globals = instance.globals(); switch (type) { - case WasmType.I32_TYPE: - globals.storeInt(globalAddress, popInt(frame, stackPointer)); - break; - case WasmType.F32_TYPE: - globals.storeFloat(globalAddress, popFloat(frame, stackPointer)); - break; - case WasmType.I64_TYPE: - globals.storeLong(globalAddress, popLong(frame, stackPointer)); - break; - case WasmType.F64_TYPE: - globals.storeDouble(globalAddress, popDouble(frame, stackPointer)); - break; - case WasmType.V128_TYPE: - globals.storeVector128(globalAddress, vector128Ops().toVector128(popVector128(frame, stackPointer))); - break; - case WasmType.FUNCREF_TYPE: - case WasmType.EXTERNREF_TYPE: - case WasmType.EXNREF_TYPE: + case WasmType.I32_TYPE -> globals.storeInt(globalAddress, popInt(frame, stackPointer)); + case WasmType.F32_TYPE -> globals.storeFloat(globalAddress, popFloat(frame, stackPointer)); + case WasmType.I64_TYPE -> globals.storeLong(globalAddress, popLong(frame, stackPointer)); + case WasmType.F64_TYPE -> globals.storeDouble(globalAddress, popDouble(frame, stackPointer)); + case WasmType.V128_TYPE -> globals.storeVector128(globalAddress, vector128Ops().toVector128(popVector128(frame, stackPointer))); + default -> { + assert WasmType.isReferenceType(type); globals.storeReference(globalAddress, popReference(frame, stackPointer)); - break; - default: - throw WasmException.create(Failure.UNSPECIFIED_TRAP, this, "Global variable cannot have the void type."); + } } } private void global_get(WasmInstance instance, VirtualFrame frame, int stackPointer, int index) { - final byte type = module.symbolTable().globalValueType(index); + final int type = module.symbolTable().globalValueType(index); CompilerAsserts.partialEvaluationConstant(type); final int globalAddress = module.symbolTable().globalAddress(index); final GlobalRegistry globals = instance.globals(); switch (type) { - case WasmType.I32_TYPE: - pushInt(frame, stackPointer, globals.loadAsInt(globalAddress)); - break; - case WasmType.F32_TYPE: - pushFloat(frame, stackPointer, globals.loadAsFloat(globalAddress)); - break; - case WasmType.I64_TYPE: - pushLong(frame, stackPointer, globals.loadAsLong(globalAddress)); - break; - case WasmType.F64_TYPE: - pushDouble(frame, stackPointer, globals.loadAsDouble(globalAddress)); - break; - case WasmType.V128_TYPE: - pushVector128(frame, stackPointer, vector128Ops().fromVector128(globals.loadAsVector128(globalAddress))); - break; - case WasmType.FUNCREF_TYPE: - case WasmType.EXTERNREF_TYPE: - case WasmType.EXNREF_TYPE: + case WasmType.I32_TYPE -> pushInt(frame, stackPointer, globals.loadAsInt(globalAddress)); + case WasmType.F32_TYPE -> pushFloat(frame, stackPointer, globals.loadAsFloat(globalAddress)); + case WasmType.I64_TYPE -> pushLong(frame, stackPointer, globals.loadAsLong(globalAddress)); + case WasmType.F64_TYPE -> pushDouble(frame, stackPointer, globals.loadAsDouble(globalAddress)); + case WasmType.V128_TYPE -> pushVector128(frame, stackPointer, vector128Ops().fromVector128(globals.loadAsVector128(globalAddress))); + default -> { + assert WasmType.isReferenceType(type); pushReference(frame, stackPointer, globals.loadAsReference(globalAddress)); - break; - default: - throw WasmException.create(Failure.UNSPECIFIED_TRAP, this, "Global variable cannot have the void type."); + } } } @@ -4547,7 +4636,7 @@ private Object[] createArgumentsForCall(VirtualFrame frame, int functionTypeInde int stackPointer = stackPointerOffset; for (int i = numArgs - 1; i >= 0; --i) { stackPointer--; - byte type = module.symbolTable().functionTypeParamTypeAt(functionTypeIndex, i); + int type = module.symbolTable().functionTypeParamTypeAt(functionTypeIndex, i); CompilerAsserts.partialEvaluationConstant(type); Object arg = switch (type) { case WasmType.I32_TYPE -> popInt(frame, stackPointer); @@ -4555,8 +4644,10 @@ private Object[] createArgumentsForCall(VirtualFrame frame, int functionTypeInde case WasmType.F32_TYPE -> popFloat(frame, stackPointer); case WasmType.F64_TYPE -> popDouble(frame, stackPointer); case WasmType.V128_TYPE -> vector128Ops().toVector128(popVector128(frame, stackPointer)); - case WasmType.FUNCREF_TYPE, WasmType.EXTERNREF_TYPE, WasmType.EXNREF_TYPE -> popReference(frame, stackPointer); - default -> throw WasmException.format(Failure.UNSPECIFIED_TRAP, this, "Unknown type: %d", type); + default -> { + assert WasmType.isReferenceType(type); + yield popReference(frame, stackPointer); + } }; WasmArguments.setArgument(args, i, arg); } @@ -4577,7 +4668,7 @@ private Object[] createFieldsForException(VirtualFrame frame, int functionTypeIn int stackPointer = stackPointerOffset; for (int i = numFields - 1; i >= 0; --i) { stackPointer--; - byte type = module.symbolTable().functionTypeParamTypeAt(functionTypeIndex, i); + int type = module.symbolTable().functionTypeParamTypeAt(functionTypeIndex, i); CompilerAsserts.partialEvaluationConstant(type); final Object arg = switch (type) { case WasmType.I32_TYPE -> popInt(frame, stackPointer); @@ -4585,8 +4676,10 @@ private Object[] createFieldsForException(VirtualFrame frame, int functionTypeIn case WasmType.F32_TYPE -> popFloat(frame, stackPointer); case WasmType.F64_TYPE -> popDouble(frame, stackPointer); case WasmType.V128_TYPE -> vector128Ops().toVector128(popVector128(frame, stackPointer)); - case WasmType.FUNCREF_TYPE, WasmType.EXTERNREF_TYPE, WasmType.EXNREF_TYPE -> popReference(frame, stackPointer); - default -> throw WasmException.format(Failure.UNSPECIFIED_TRAP, this, "Unknown type: %d", type); + default -> { + assert WasmType.isReferenceType(type); + yield popReference(frame, stackPointer); + } }; fields[i] = arg; } @@ -4601,7 +4694,7 @@ private void pushExceptionFields(VirtualFrame frame, WasmRuntimeException e, int final Object[] fields = e.fields(); int stackPointer = stackPointerOffset; for (int i = 0; i < numFields; i++) { - byte type = module.symbolTable().functionTypeParamTypeAt(functionTypeIndex, i); + int type = module.symbolTable().functionTypeParamTypeAt(functionTypeIndex, i); CompilerAsserts.partialEvaluationConstant(type); switch (type) { case WasmType.I32_TYPE -> pushInt(frame, stackPointer, (int) fields[i]); @@ -4609,8 +4702,10 @@ private void pushExceptionFields(VirtualFrame frame, WasmRuntimeException e, int case WasmType.F32_TYPE -> pushFloat(frame, stackPointer, (float) fields[i]); case WasmType.F64_TYPE -> pushDouble(frame, stackPointer, (double) fields[i]); case WasmType.V128_TYPE -> pushVector128(frame, stackPointer, vector128Ops().fromVector128((Vector128) fields[i])); - case WasmType.FUNCREF_TYPE, WasmType.EXTERNREF_TYPE, WasmType.EXNREF_TYPE -> pushReference(frame, stackPointer, fields[i]); - default -> throw WasmException.format(Failure.UNSPECIFIED_TRAP, this, "Unknown type: %d", type); + default -> { + assert WasmType.isReferenceType(type); + pushReference(frame, stackPointer, fields[i]); + } } stackPointer++; } @@ -4787,7 +4882,7 @@ private int pushDirectCallResult(VirtualFrame frame, int stackPointer, WasmFunct if (resultCount == 0) { return stackPointer; } else if (resultCount == 1) { - final byte resultType = function.resultTypeAt(0); + final int resultType = function.resultTypeAt(0); pushResult(frame, stackPointer, resultType, result); return stackPointer + 1; } else { @@ -4803,7 +4898,7 @@ private int pushIndirectCallResult(VirtualFrame frame, int stackPointer, int exp if (resultCount == 0) { return stackPointer; } else if (resultCount == 1) { - final byte resultType = module.symbolTable().functionTypeResultTypeAt(expectedFunctionTypeIndex, 0); + final int resultType = module.symbolTable().functionTypeResultTypeAt(expectedFunctionTypeIndex, 0); pushResult(frame, stackPointer, resultType, result); return stackPointer + 1; } else { @@ -4812,7 +4907,7 @@ private int pushIndirectCallResult(VirtualFrame frame, int stackPointer, int exp } } - private void pushResult(VirtualFrame frame, int stackPointer, byte resultType, Object result) { + private void pushResult(VirtualFrame frame, int stackPointer, int resultType, Object result) { CompilerAsserts.partialEvaluationConstant(resultType); switch (resultType) { case WasmType.I32_TYPE -> pushInt(frame, stackPointer, (int) result); @@ -4820,9 +4915,9 @@ private void pushResult(VirtualFrame frame, int stackPointer, byte resultType, O case WasmType.F32_TYPE -> pushFloat(frame, stackPointer, (float) result); case WasmType.F64_TYPE -> pushDouble(frame, stackPointer, (double) result); case WasmType.V128_TYPE -> pushVector128(frame, stackPointer, vector128Ops().fromVector128((Vector128) result)); - case WasmType.FUNCREF_TYPE, WasmType.EXTERNREF_TYPE, WasmType.EXNREF_TYPE -> pushReference(frame, stackPointer, result); default -> { - throw WasmException.format(Failure.UNSPECIFIED_TRAP, this, "Unknown result type: %d", resultType); + assert WasmType.isReferenceType(resultType); + pushReference(frame, stackPointer, result); } } } @@ -4845,7 +4940,7 @@ private void extractMultiValueResult(VirtualFrame frame, int stackPointer, Objec final long[] primitiveMultiValueStack = multiValueStack.primitiveStack(); final Object[] objectMultiValueStack = multiValueStack.objectStack(); for (int i = 0; i < resultCount; i++) { - final byte resultType = module.symbolTable().functionTypeResultTypeAt(functionTypeIndex, i); + final int resultType = module.symbolTable().functionTypeResultTypeAt(functionTypeIndex, i); CompilerAsserts.partialEvaluationConstant(resultType); switch (resultType) { case WasmType.I32_TYPE -> pushInt(frame, stackPointer + i, (int) primitiveMultiValueStack[i]); @@ -4856,14 +4951,11 @@ private void extractMultiValueResult(VirtualFrame frame, int stackPointer, Objec pushVector128(frame, stackPointer + i, vector128Ops().fromVector128((Vector128) objectMultiValueStack[i])); objectMultiValueStack[i] = null; } - case WasmType.FUNCREF_TYPE, WasmType.EXTERNREF_TYPE, WasmType.EXNREF_TYPE -> { + default -> { + assert WasmType.isReferenceType(resultType); pushReference(frame, stackPointer + i, objectMultiValueStack[i]); objectMultiValueStack[i] = null; } - default -> { - enterErrorBranch(); - throw WasmException.format(Failure.UNSPECIFIED_TRAP, this, "Unknown result type: %d", resultType); - } } } } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionRootNode.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionRootNode.java index d1a8bf480031..01000a040a76 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionRootNode.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionRootNode.java @@ -106,7 +106,7 @@ void enterErrorBranch() { codeEntry.errorBranch(); } - byte resultType(int index) { + int resultType(int index) { return codeEntry.resultType(index); } @@ -114,7 +114,7 @@ int paramCount() { return module().symbolTable().function(codeEntry.functionIndex()).paramCount(); } - byte localType(int index) { + int localType(int index) { return codeEntry.localType(index); } @@ -154,28 +154,19 @@ public Object executeWithInstance(VirtualFrame frame, WasmInstance instance) { if (resultCount == 0) { return WasmConstant.VOID; } else if (resultCount == 1) { - final byte resultType = resultType(0); + final int resultType = resultType(0); CompilerAsserts.partialEvaluationConstant(resultType); - switch (resultType) { - case WasmType.VOID_TYPE: - return WasmConstant.VOID; - case WasmType.I32_TYPE: - return popInt(frame, localCount); - case WasmType.I64_TYPE: - return popLong(frame, localCount); - case WasmType.F32_TYPE: - return popFloat(frame, localCount); - case WasmType.F64_TYPE: - return popDouble(frame, localCount); - case WasmType.V128_TYPE: - return Vector128Ops.SINGLETON_IMPLEMENTATION.toVector128(popVector128(frame, localCount)); - case WasmType.FUNCREF_TYPE: - case WasmType.EXTERNREF_TYPE: - case WasmType.EXNREF_TYPE: - return popReference(frame, localCount); - default: - throw WasmException.format(Failure.UNSPECIFIED_INTERNAL, this, "Unknown result type: %d", resultType); - } + return switch (resultType) { + case WasmType.I32_TYPE -> popInt(frame, localCount); + case WasmType.I64_TYPE -> popLong(frame, localCount); + case WasmType.F32_TYPE -> popFloat(frame, localCount); + case WasmType.F64_TYPE -> popDouble(frame, localCount); + case WasmType.V128_TYPE -> Vector128Ops.SINGLETON_IMPLEMENTATION.toVector128(popVector128(frame, localCount)); + default -> { + assert WasmType.isReferenceType(resultType); + yield popReference(frame, localCount); + } + }; } else { moveResultValuesToMultiValueStack(frame, resultCount, localCount); return WasmConstant.MULTI_VALUE; @@ -192,28 +183,15 @@ private void moveResultValuesToMultiValueStack(VirtualFrame frame, int resultCou final int resultType = resultType(i); CompilerAsserts.partialEvaluationConstant(resultType); switch (resultType) { - case WasmType.I32_TYPE: - primitiveMultiValueStack[i] = popInt(frame, localCount + i); - break; - case WasmType.I64_TYPE: - primitiveMultiValueStack[i] = popLong(frame, localCount + i); - break; - case WasmType.F32_TYPE: - primitiveMultiValueStack[i] = Float.floatToRawIntBits(popFloat(frame, localCount + i)); - break; - case WasmType.F64_TYPE: - primitiveMultiValueStack[i] = Double.doubleToRawLongBits(popDouble(frame, localCount + i)); - break; - case WasmType.V128_TYPE: - objectMultiValueStack[i] = Vector128Ops.SINGLETON_IMPLEMENTATION.toVector128(popVector128(frame, localCount + i)); - break; - case WasmType.FUNCREF_TYPE: - case WasmType.EXTERNREF_TYPE: - case WasmType.EXNREF_TYPE: + case WasmType.I32_TYPE -> primitiveMultiValueStack[i] = popInt(frame, localCount + i); + case WasmType.I64_TYPE -> primitiveMultiValueStack[i] = popLong(frame, localCount + i); + case WasmType.F32_TYPE -> primitiveMultiValueStack[i] = Float.floatToRawIntBits(popFloat(frame, localCount + i)); + case WasmType.F64_TYPE -> primitiveMultiValueStack[i] = Double.doubleToRawLongBits(popDouble(frame, localCount + i)); + case WasmType.V128_TYPE -> objectMultiValueStack[i] = Vector128Ops.SINGLETON_IMPLEMENTATION.toVector128(popVector128(frame, localCount + i)); + default -> { + assert WasmType.isReferenceType(resultType); objectMultiValueStack[i] = popReference(frame, localCount + i); - break; - default: - throw WasmException.format(Failure.UNSPECIFIED_INTERNAL, this, "Unknown result type: %d", resultType); + } } } } @@ -225,28 +203,17 @@ private void moveArgumentsToLocals(VirtualFrame frame) { assert WasmArguments.getArgumentCount(args) == paramCount : "Expected number of params " + paramCount + ", actual " + WasmArguments.getArgumentCount(args); for (int i = 0; i != paramCount; ++i) { final Object arg = WasmArguments.getArgument(args, i); - byte type = localType(i); + int type = localType(i); switch (type) { - case WasmType.I32_TYPE: - pushInt(frame, i, (int) arg); - break; - case WasmType.I64_TYPE: - pushLong(frame, i, (long) arg); - break; - case WasmType.F32_TYPE: - pushFloat(frame, i, (float) arg); - break; - case WasmType.F64_TYPE: - pushDouble(frame, i, (double) arg); - break; - case WasmType.V128_TYPE: - pushVector128(frame, i, Vector128Ops.SINGLETON_IMPLEMENTATION.fromVector128((Vector128) arg)); - break; - case WasmType.FUNCREF_TYPE: - case WasmType.EXTERNREF_TYPE: - case WasmType.EXNREF_TYPE: + case WasmType.I32_TYPE -> pushInt(frame, i, (int) arg); + case WasmType.I64_TYPE -> pushLong(frame, i, (long) arg); + case WasmType.F32_TYPE -> pushFloat(frame, i, (float) arg); + case WasmType.F64_TYPE -> pushDouble(frame, i, (double) arg); + case WasmType.V128_TYPE -> pushVector128(frame, i, Vector128Ops.SINGLETON_IMPLEMENTATION.fromVector128((Vector128) arg)); + default -> { + assert WasmType.isReferenceType(type); pushReference(frame, i, arg); - break; + } } } } @@ -255,28 +222,17 @@ private void moveArgumentsToLocals(VirtualFrame frame) { private void initializeLocals(VirtualFrame frame) { int paramCount = paramCount(); for (int i = paramCount; i != localCount(); ++i) { - byte type = localType(i); + int type = localType(i); switch (type) { - case WasmType.I32_TYPE: - pushInt(frame, i, 0); - break; - case WasmType.I64_TYPE: - pushLong(frame, i, 0L); - break; - case WasmType.F32_TYPE: - pushFloat(frame, i, 0F); - break; - case WasmType.F64_TYPE: - pushDouble(frame, i, 0D); - break; - case WasmType.V128_TYPE: - pushVector128(frame, i, Vector128Ops.SINGLETON_IMPLEMENTATION.fromVector128(Vector128.ZERO)); - break; - case WasmType.FUNCREF_TYPE: - case WasmType.EXTERNREF_TYPE: - case WasmType.EXNREF_TYPE: + case WasmType.I32_TYPE -> pushInt(frame, i, 0); + case WasmType.I64_TYPE -> pushLong(frame, i, 0L); + case WasmType.F32_TYPE -> pushFloat(frame, i, 0F); + case WasmType.F64_TYPE -> pushDouble(frame, i, 0D); + case WasmType.V128_TYPE -> pushVector128(frame, i, Vector128Ops.SINGLETON_IMPLEMENTATION.fromVector128(Vector128.ZERO)); + default -> { + WasmType.isReferenceType(type); pushReference(frame, i, WasmConstant.NULL); - break; + } } } } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmRootNode.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmRootNode.java index 0d90a6ccf5ef..351c96467a27 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmRootNode.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmRootNode.java @@ -84,7 +84,7 @@ public final void tryInitialize(WasmInstance instance) { // We want to ensure that linking always precedes the running of the WebAssembly code. // This linking should be as late as possible, because a WebAssembly context should // be able to parse multiple modules before the code gets run. - if (!instance.isLinkCompletedFastPath()) { + if (getContext().getContextOptions().evalReturnsInstance() && !instance.isLinkCompletedFastPath()) { nonLinkedProfile.enter(); instance.store().linker().tryLinkFastPath(instance); } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/BytecodeParser.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/BytecodeParser.java index cafab53c9de7..3845cc7e4d7a 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/BytecodeParser.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/BytecodeParser.java @@ -56,7 +56,8 @@ import org.graalvm.wasm.WasmInstance; import org.graalvm.wasm.WasmModule; import org.graalvm.wasm.WasmStore; -import org.graalvm.wasm.collection.ByteArrayList; +import org.graalvm.wasm.WasmType; +import org.graalvm.wasm.collection.IntArrayList; import org.graalvm.wasm.constants.Bytecode; import org.graalvm.wasm.constants.BytecodeBitEncoding; import org.graalvm.wasm.constants.SegmentMode; @@ -223,14 +224,32 @@ public static void resetMemoryState(WasmModule module, WasmInstance instance) { */ public static void resetTableState(WasmStore store, WasmModule module, WasmInstance instance) { final byte[] bytecode = module.bytecode(); + for (int tableIndex = 0; tableIndex < module.tableCount(); tableIndex++) { + if (module.tableInitialValue(tableIndex) != null) { + Linker.initializeTable(instance, tableIndex, module.tableInitialValue(tableIndex)); + } else if (module.tableInitializerBytecode(tableIndex) != null) { + Linker.initializeTable(instance, tableIndex, Linker.evalConstantExpression(instance, module.tableInitializerBytecode(tableIndex))); + } + } for (int i = 0; i < module.elemInstanceCount(); i++) { final int elemOffset = module.elemInstanceOffset(i); final int flags = bytecode[elemOffset]; - final int typeAndMode = bytecode[elemOffset + 1]; + final int typeLengthAndMode = bytecode[elemOffset + 1]; int effectiveOffset = elemOffset + 2; - final int elemMode = typeAndMode & BytecodeBitEncoding.ELEM_SEG_MODE_VALUE; + final int elemMode = typeLengthAndMode & BytecodeBitEncoding.ELEM_SEG_MODE_VALUE; + switch (typeLengthAndMode & BytecodeBitEncoding.ELEM_SEG_TYPE_MASK) { + case BytecodeBitEncoding.ELEM_SEG_TYPE_I8: + effectiveOffset++; + break; + case BytecodeBitEncoding.ELEM_SEG_TYPE_I16: + effectiveOffset += 2; + break; + case BytecodeBitEncoding.ELEM_SEG_TYPE_I32: + effectiveOffset += 4; + break; + } final int elemCount; switch (flags & BytecodeBitEncoding.ELEM_SEG_COUNT_MASK) { case BytecodeBitEncoding.ELEM_SEG_COUNT_U8: @@ -406,26 +425,26 @@ public static CodeEntry readCodeEntry(WasmModule module, byte[] bytecode, int co default: throw CompilerDirectives.shouldNotReachHere(); } - final byte[] locals; + final int[] locals; if ((flags & BytecodeBitEncoding.CODE_ENTRY_LOCALS_FLAG) != 0) { - ByteArrayList localsList = new ByteArrayList(); - for (; bytecode[effectiveOffset] != 0; effectiveOffset++) { - localsList.add(bytecode[effectiveOffset]); + IntArrayList localsList = new IntArrayList(); + for (; bytecode[effectiveOffset] != 0; effectiveOffset += 4) { + localsList.add(BinaryStreamParser.peek4(bytecode, effectiveOffset)); } effectiveOffset++; locals = localsList.toArray(); } else { - locals = Bytecode.EMPTY_BYTES; + locals = WasmType.VOID_TYPE_ARRAY; } - final byte[] results; + final int[] results; if ((flags & BytecodeBitEncoding.CODE_ENTRY_RESULT_FLAG) != 0) { - ByteArrayList resultsList = new ByteArrayList(); - for (; bytecode[effectiveOffset] != 0; effectiveOffset++) { - resultsList.add(bytecode[effectiveOffset]); + IntArrayList resultsList = new IntArrayList(); + for (; bytecode[effectiveOffset] != 0; effectiveOffset += 4) { + resultsList.add(BinaryStreamParser.peek4(bytecode, effectiveOffset)); } results = resultsList.toArray(); } else { - results = Bytecode.EMPTY_BYTES; + results = WasmType.VOID_TYPE_ARRAY; } final int endOffset; if (exceptionTableOffset == BytecodeBitEncoding.INVALID_EXCEPTION_TABLE_OFFSET) { @@ -464,12 +483,14 @@ private static List readCallNodes(byte[] bytecode, int startOffset, in offset += 8; break; } - case Bytecode.CALL_INDIRECT_U8: { + case Bytecode.CALL_INDIRECT_U8: + case Bytecode.CALL_REF_U8: { callNodes.add(new CallNode(originalOffset)); offset += 3; break; } - case Bytecode.CALL_INDIRECT_I32: { + case Bytecode.CALL_INDIRECT_I32: + case Bytecode.CALL_REF_I32: { callNodes.add(new CallNode(originalOffset)); offset += 12; break; @@ -703,9 +724,7 @@ private static List readCallNodes(byte[] bytecode, int startOffset, in case Bytecode.I64_STORE_32_I32: case Bytecode.I32_CONST_I32: case Bytecode.F32_CONST: - case Bytecode.REF_FUNC: - case Bytecode.TABLE_GET: - case Bytecode.TABLE_SET: { + case Bytecode.REF_FUNC: { offset += 4; break; } @@ -789,20 +808,31 @@ private static List readCallNodes(byte[] bytecode, int startOffset, in case Bytecode.THROW_REF: { break; } + case Bytecode.BR_ON_NULL_U8: + case Bytecode.BR_ON_NON_NULL_U8: { + offset += 3; + break; + } case Bytecode.MEMORY_FILL: case Bytecode.MEMORY64_FILL: case Bytecode.MEMORY64_SIZE: case Bytecode.MEMORY64_GROW: case Bytecode.DATA_DROP: - case Bytecode.DATA_DROP_UNSAFE: case Bytecode.ELEM_DROP: case Bytecode.TABLE_GROW: case Bytecode.TABLE_SIZE: case Bytecode.TABLE_FILL: - case Bytecode.THROW: { + case Bytecode.THROW: + case Bytecode.TABLE_GET: + case Bytecode.TABLE_SET: { offset += 4; break; } + case Bytecode.BR_ON_NULL_I32: + case Bytecode.BR_ON_NON_NULL_I32: { + offset += 6; + break; + } case Bytecode.MEMORY_INIT: case Bytecode.MEMORY64_INIT: case Bytecode.MEMORY_COPY: diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/RuntimeBytecodeGen.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/RuntimeBytecodeGen.java index 6c000c36d1ca..af5b59225302 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/RuntimeBytecodeGen.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/RuntimeBytecodeGen.java @@ -47,8 +47,6 @@ import org.graalvm.wasm.constants.BytecodeBitEncoding; import org.graalvm.wasm.constants.SegmentMode; -import com.oracle.truffle.api.CompilerDirectives; - /** * A data structure for generating the GraalWasm runtime bytecode. */ @@ -74,6 +72,10 @@ private static boolean fitsIntoUnsignedByte(long value) { return Long.compareUnsigned(value, 255) <= 0; } + private static boolean fitsIntoSignedShort(int value) { + return value >= Short.MIN_VALUE && value <= Short.MAX_VALUE; + } + private static boolean fitsIntoUnsignedShort(int value) { return Integer.compareUnsigned(value, 65535) <= 0; } @@ -297,7 +299,7 @@ public void addExtendedMemoryInstruction(int opcode, int memoryIndex, long offse * @param resultCount The number of results of the block. * @param stackSize The stack size at the start of the block. * @param commonResultType The most common result type of the result types of the block. See - * {@link WasmType#getCommonValueType(byte[])}. + * {@link WasmType#getCommonValueType(int[])}. * @return The location of the label in the bytecode. */ public int addLabel(int resultCount, int stackSize, int commonResultType) { @@ -342,7 +344,7 @@ public int addLabel(int resultCount, int stackSize, int commonResultType) { * @param resultCount The number of results of the loop. * @param stackSize The stack size at the start of the loop. * @param commonResultType The most common result type of the result types of the loop. See - * {@link WasmType#getCommonValueType(byte[])}. + * {@link WasmType#getCommonValueType(int[])}. * @return The location of the loop label in the bytecode. */ public int addLoopLabel(int resultCount, int stackSize, int commonResultType) { @@ -354,7 +356,7 @@ public int addLoopLabel(int resultCount, int stackSize, int commonResultType) { /** * Adds an if opcode to the bytecode and reserves an i32 value for the jump offset and a 2-byte * profile. - * + * * @return The location of the jump offset to be patched later. (see * {@link #patchLocation(int, int)}. */ @@ -368,77 +370,82 @@ public int addIfLocation() { return location; } - /** - * Adds a branch opcode to the bytecode. If the negative jump offset fits into a u8 value, a - * br_u8 and u8 jump offset is added (The jump offset is encoded as a positive value). - * Otherwise, a br_i32 and i32 jump offset is added. - * - * @param location The target location of the branch. - */ - public void addBranch(int location) { - assert location >= 0; - final int relativeOffset = location - (location() + 1); - if (relativeOffset <= 0 && relativeOffset >= -255) { - add1(Bytecode.BR_U8); - add1(-relativeOffset); - } else { - add1(Bytecode.BR_I32); - add4(relativeOffset); + public enum BranchOp { + BR(op(Bytecode.BR_U8), op(Bytecode.BR_I32), false), + BR_IF(op(Bytecode.BR_IF_U8), op(Bytecode.BR_IF_I32), true), + BR_ON_NULL(miscOp(Bytecode.BR_ON_NULL_U8), miscOp(Bytecode.BR_ON_NULL_I32), true), + BR_ON_NON_NULL(miscOp(Bytecode.BR_ON_NON_NULL_U8), miscOp(Bytecode.BR_ON_NON_NULL_I32), true); + + private final byte[] opcodesU8; + private final byte[] opcodesI32; + private final boolean profiled; + + BranchOp(byte[] opcodesU8, byte[] opcodesI32, boolean profiled) { + this.opcodesU8 = opcodesU8; + this.opcodesI32 = opcodesI32; + this.profiled = profiled; } - } - /** - * Adds a br_i32 instruction to the bytecode and reserves an i32 value for the jump offset. - * - * @return The location of the jump offset to be patched later. (see - * {@link #patchLocation(int, int)}). - */ - public int addBranchLocation() { - add1(Bytecode.BR_I32); - final int location = location(); - add4(0); - return location; + public void emitOpcodesU8(RuntimeBytecodeGen bytecode) { + bytecode.addBytes(opcodesU8, 0, opcodesU8.length); + } + + public void emitOpcodesI32(RuntimeBytecodeGen bytecode) { + bytecode.addBytes(opcodesI32, 0, opcodesI32.length); + } + + public void emitProfile(RuntimeBytecodeGen bytecode) { + if (profiled) { + bytecode.addProfile(); + } + } + + private static byte[] op(int opcode) { + return new byte[]{(byte) opcode}; + } + + private static byte[] miscOp(int opcode) { + return new byte[]{(byte) Bytecode.MISC, (byte) opcode}; + } } /** - * Adds a conditional branch opcode to the bytecode. If the jump offset fits into a signed i8 - * value, a br_if_i8 and i8 jump offset is added. Otherwise, a br_if_i32 and i32 jump offset is - * added. In both cases, a profile with a size of 2-byte is added. - * + * Adds a branch opcode to the bytecode. If the jump offset fits into a signed i8 value, + * {@code opcodesU8} and i8 jump offset is added. Otherwise, {@code opcodesI32} and i32 jump + * offset is added. In both cases, a profile with a size of 2-byte is added. + * * @param location The target location of the branch. */ - public void addBranchIf(int location) { + public void addBranch(int location, BranchOp branchOp) { assert location >= 0; final int relativeOffset = location - (location() + 1); if (relativeOffset <= 0 && relativeOffset >= -255) { - add1(Bytecode.BR_IF_U8); + branchOp.emitOpcodesU8(this); // target add1(-relativeOffset); - // profile - addProfile(); } else { - add1(Bytecode.BR_IF_I32); + branchOp.emitOpcodesI32(this); // target add4(relativeOffset); - // profile - addProfile(); } + // profile + branchOp.emitProfile(this); } /** - * Adds a br_if_i32 opcode to the bytecode and reserves an i32 value for the jump offset. In + * Adds a branch opcode to the bytecode and reserves an i32 value for the jump offset. In * addition, a profile with a size of 2-byte is added. - * + * * @return The location of the jump offset to be patched later. (see * {@link #patchLocation(int, int)}) */ - public int addBranchIfLocation() { - add1(Bytecode.BR_IF_I32); + public int addBranchLocation(BranchOp branchOp) { + branchOp.emitOpcodesI32(this); final int location = location(); // target add4(0); // profile - addProfile(); + branchOp.emitProfile(this); return location; } @@ -537,9 +544,8 @@ public void addCall(int nodeIndex, int functionIndex) { /** * Adds an indirect call instruction to the bytecode. If the nodeIndex, typeIndex, and * tableIndex all fit into a u8 value, a call_indirect_u8 and three u8 values are added. - * Otherwise, a call_indirect_i32 and three i32 values are added. In both cases, a 2-byte - * profile is added. - * + * Otherwise, a call_indirect_i32 and three i32 values are added. + * * @param nodeIndex The node index of the indirect call * @param typeIndex The type index of the indirect call * @param tableIndex The table index of the indirect call @@ -558,6 +564,26 @@ public void addIndirectCall(int nodeIndex, int typeIndex, int tableIndex) { } } + /** + * Adds a reference call instruction to the bytecode. If the nodeIndex and typeIndex both fit + * into a u8 value, a call_ref_u8 and two u8 values are added. Otherwise, a call_ref_i32 and two + * i32 values are added. + * + * @param nodeIndex The node index of the reference call + * @param typeIndex The type index of the reference call + */ + public void addRefCall(int nodeIndex, int typeIndex) { + if (fitsIntoUnsignedByte(nodeIndex) && fitsIntoUnsignedByte(typeIndex)) { + add1(Bytecode.CALL_REF_U8); + add1(nodeIndex); + add1(typeIndex); + } else { + add1(Bytecode.CALL_REF_I32); + add4(nodeIndex); + add4(typeIndex); + } + } + public void addSelect(int instruction) { add1(instruction); addProfile(); @@ -692,20 +718,26 @@ public void addDataRuntimeHeader(int length) { * @param offsetAddress The offset address of the elem segment, -1 if missing * @return The location after the header in the bytecode */ - public int addElemHeader(int mode, int count, byte elemType, int tableIndex, byte[] offsetBytecode, int offsetAddress) { + public int addElemHeader(int mode, int count, int elemType, int tableIndex, byte[] offsetBytecode, int offsetAddress) { assert offsetBytecode == null || offsetAddress == -1 : "elem header does not allow offset bytecode and offset address"; assert mode == SegmentMode.ACTIVE || mode == SegmentMode.PASSIVE || mode == SegmentMode.DECLARATIVE : "invalid segment mode in elem header"; assert WasmType.isReferenceType(elemType) : "invalid elem type in elem header"; - int location = location(); + int flagsLocation = location(); + add1(0); + int modeLocation = location(); add1(0); - final int type = switch (elemType) { - case WasmType.FUNCREF_TYPE -> BytecodeBitEncoding.ELEM_SEG_TYPE_FUNREF; - case WasmType.EXTERNREF_TYPE -> BytecodeBitEncoding.ELEM_SEG_TYPE_EXTERNREF; - case WasmType.EXNREF_TYPE -> BytecodeBitEncoding.ELEM_SEG_TYPE_EXNREF; - default -> throw CompilerDirectives.shouldNotReachHere(); - }; - add1(type | mode); + int typeLengthAndMode = mode; + if (fitsIntoSignedByte(elemType)) { + typeLengthAndMode |= BytecodeBitEncoding.ELEM_SEG_TYPE_I8; + add1(elemType); + } else if (fitsIntoSignedShort(elemType)) { + typeLengthAndMode |= BytecodeBitEncoding.ELEM_SEG_TYPE_I16; + add2(elemType); + } else { + typeLengthAndMode |= BytecodeBitEncoding.ELEM_SEG_TYPE_I32; + add4(elemType); + } int flags = 0; if (fitsIntoUnsignedByte(count)) { flags |= BytecodeBitEncoding.ELEM_SEG_COUNT_U8; @@ -754,7 +786,8 @@ public int addElemHeader(int mode, int count, byte elemType, int tableIndex, byt add4(offsetAddress); } } - set(location, (byte) flags); + set(flagsLocation, (byte) flags); + set(modeLocation, (byte) typeLengthAndMode); return location(); } @@ -767,6 +800,15 @@ public void addByte(byte value) { add1(value); } + /** + * Adds a value type to the bytecode. + * + * @param type The value type that should be added + */ + public void addType(int type) { + add4(type); + } + /** * Adds a null entry to the data of an elem segment. */ diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/ir/CodeEntry.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/ir/CodeEntry.java index 2e7624f7d01c..4dff60c0b04c 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/ir/CodeEntry.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/ir/CodeEntry.java @@ -49,15 +49,15 @@ public final class CodeEntry { private final int functionIndex; private final int maxStackSize; - private final byte[] localTypes; - private final byte[] resultTypes; + private final int[] localTypes; + private final int[] resultTypes; private final List callNodes; private final int bytecodeStartOffset; private final int bytecodeEndOffset; private final boolean usesMemoryZero; private final int exceptionTableOffset; - public CodeEntry(int functionIndex, int maxStackSize, byte[] localTypes, byte[] resultTypes, List callNodes, int startOffset, int endOffset, boolean usesMemoryZero, + public CodeEntry(int functionIndex, int maxStackSize, int[] localTypes, int[] resultTypes, List callNodes, int startOffset, int endOffset, boolean usesMemoryZero, int exceptionTableOffset) { this.functionIndex = functionIndex; this.maxStackSize = maxStackSize; @@ -78,11 +78,11 @@ public int functionIndex() { return functionIndex; } - public byte[] localTypes() { + public int[] localTypes() { return localTypes; } - public byte[] resultTypes() { + public int[] resultTypes() { return resultTypes; } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/BlockFrame.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/BlockFrame.java index 3f164c40c9d9..046958abc330 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/BlockFrame.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/BlockFrame.java @@ -42,7 +42,10 @@ package org.graalvm.wasm.parser.validation; import java.util.ArrayList; +import java.util.BitSet; +import org.graalvm.wasm.SymbolTable; +import org.graalvm.wasm.WasmType; import org.graalvm.wasm.collection.IntArrayList; import org.graalvm.wasm.exception.Failure; import org.graalvm.wasm.exception.WasmException; @@ -55,14 +58,28 @@ class BlockFrame extends ControlFrame { private final IntArrayList branches; private final ArrayList exceptionHandlers; - BlockFrame(byte[] paramTypes, byte[] resultTypes, int initialStackSize, boolean unreachable) { - super(paramTypes, resultTypes, initialStackSize, unreachable); + private BlockFrame(int[] paramTypes, int[] resultTypes, SymbolTable symbolTable, int initialStackSize, BitSet initializedLocals) { + super(paramTypes, resultTypes, symbolTable, initialStackSize, initializedLocals); branches = new IntArrayList(); exceptionHandlers = new ArrayList<>(); } + BlockFrame(int[] paramTypes, int[] resultTypes, int initialStackSize, ControlFrame parentFrame) { + this(paramTypes, resultTypes, parentFrame.getSymbolTable(), initialStackSize, (BitSet) parentFrame.initializedLocals.clone()); + } + + static BlockFrame createFunctionFrame(int[] paramTypes, int[] resultTypes, int[] locals, SymbolTable symbolTable) { + BitSet initializedLocals = new BitSet(locals.length); + for (int localIndex = 0; localIndex < locals.length; localIndex++) { + if (localIndex < paramTypes.length || WasmType.hasDefaultValue(locals[localIndex])) { + initializedLocals.set(localIndex); + } + } + return new BlockFrame(paramTypes, resultTypes, symbolTable, 0, initializedLocals); + } + @Override - byte[] labelTypes() { + int[] labelTypes() { return resultTypes(); } @@ -86,13 +103,8 @@ void exit(RuntimeBytecodeGen bytecode) { } @Override - void addBranch(RuntimeBytecodeGen bytecode) { - branches.add(bytecode.addBranchLocation()); - } - - @Override - void addBranchIf(RuntimeBytecodeGen bytecode) { - branches.add(bytecode.addBranchIfLocation()); + void addBranch(RuntimeBytecodeGen bytecode, RuntimeBytecodeGen.BranchOp branchOp) { + branches.add(bytecode.addBranchLocation(branchOp)); } @Override diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ControlFrame.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ControlFrame.java index 4545dffbb890..87a443c92401 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ControlFrame.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ControlFrame.java @@ -41,39 +41,49 @@ package org.graalvm.wasm.parser.validation; +import org.graalvm.wasm.SymbolTable; import org.graalvm.wasm.WasmType; import org.graalvm.wasm.parser.bytecode.RuntimeBytecodeGen; +import java.util.BitSet; + /** * Represents the scope of a block structure during module validation. */ public abstract class ControlFrame { - private final byte[] paramTypes; - private final byte[] resultTypes; + private final int[] paramTypes; + private final int[] resultTypes; + private final SymbolTable symbolTable; private final int initialStackSize; private boolean unreachable; private final int commonResultType; + protected BitSet initializedLocals; /** * @param paramTypes The parameter value types of the block structure. * @param resultTypes The result value types of the block structure. + * @param symbolTable Necessary to look up the definitions of types in {@code paramTypes} and + * {@code resultTypes} * @param initialStackSize The size of the value stack when entering this block structure. - * @param unreachable If the block structure should be declared unreachable. + * @param initializedLocals The set of locals which are already initialized at the start of this + * function */ - ControlFrame(byte[] paramTypes, byte[] resultTypes, int initialStackSize, boolean unreachable) { + ControlFrame(int[] paramTypes, int[] resultTypes, SymbolTable symbolTable, int initialStackSize, BitSet initializedLocals) { this.paramTypes = paramTypes; this.resultTypes = resultTypes; + this.symbolTable = symbolTable; this.initialStackSize = initialStackSize; - this.unreachable = unreachable; + this.unreachable = false; commonResultType = WasmType.getCommonValueType(resultTypes); + this.initializedLocals = (BitSet) initializedLocals.clone(); } - protected byte[] paramTypes() { + protected int[] paramTypes() { return paramTypes; } - public byte[] resultTypes() { + public int[] resultTypes() { return resultTypes; } @@ -81,6 +91,10 @@ protected int resultTypeLength() { return resultTypes.length; } + protected SymbolTable getSymbolTable() { + return symbolTable; + } + /** * @return The union of all result types. */ @@ -91,7 +105,7 @@ protected int commonResultType() { /** * @return The types that must be on the value stack when branching to this frame. */ - abstract byte[] labelTypes(); + abstract int[] labelTypes(); protected int labelTypeLength() { return labelTypes().length; @@ -113,6 +127,14 @@ protected void resetUnreachable() { this.unreachable = false; } + boolean isLocalInitialized(int localIndex) { + return initializedLocals.get(localIndex); + } + + void initializeLocal(int localIndex) { + initializedLocals.set(localIndex); + } + /** * Performs checks and actions when entering an else branch. * @@ -129,21 +151,12 @@ protected void resetUnreachable() { abstract void exit(RuntimeBytecodeGen bytecode); /** - * Adds an unconditional branch targeting this control frame. Automatically patches the branch - * target as soon as it is available. - * - * @param bytecode The bytecode of the current control frame. - */ - abstract void addBranch(RuntimeBytecodeGen bytecode); - - /** - * Adds a conditional branch targeting this control frame. Automatically patches the branch - * target as soon as it is available. + * Adds a branch targeting this control frame. Automatically patches the branch target as soon + * as it is available. * * @param bytecode The bytecode of the current control frame. */ - - abstract void addBranchIf(RuntimeBytecodeGen bytecode); + abstract void addBranch(RuntimeBytecodeGen bytecode, RuntimeBytecodeGen.BranchOp branchOp); /** * Adds a branch table item targeting this control frame. Automatically patches the branch diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/IfFrame.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/IfFrame.java index fef91318d510..1fef69c27988 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/IfFrame.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/IfFrame.java @@ -42,7 +42,7 @@ package org.graalvm.wasm.parser.validation; import java.util.ArrayList; -import java.util.Arrays; +import java.util.BitSet; import org.graalvm.wasm.collection.IntArrayList; import org.graalvm.wasm.exception.Failure; @@ -56,25 +56,28 @@ class IfFrame extends ControlFrame { private final IntArrayList branchTargets; private final ArrayList exceptionHandlers; + private final ControlFrame parentFrame; private int falseJumpLocation; private boolean elseBranch; - IfFrame(byte[] paramTypes, byte[] resultTypes, int initialStackSize, boolean unreachable, int falseJumpLocation) { - super(paramTypes, resultTypes, initialStackSize, unreachable); - branchTargets = new IntArrayList(); - exceptionHandlers = new ArrayList<>(); + IfFrame(int[] paramTypes, int[] resultTypes, int initialStackSize, ControlFrame parentFrame, int falseJumpLocation) { + super(paramTypes, resultTypes, parentFrame.getSymbolTable(), initialStackSize, (BitSet) parentFrame.initializedLocals.clone()); + this.branchTargets = new IntArrayList(); + this.exceptionHandlers = new ArrayList<>(); + this.parentFrame = parentFrame; this.falseJumpLocation = falseJumpLocation; this.elseBranch = false; } @Override - byte[] labelTypes() { + int[] labelTypes() { return resultTypes(); } @Override void enterElse(ParserState state, RuntimeBytecodeGen bytecode) { - final int location = bytecode.addBranchLocation(); + initializedLocals = (BitSet) parentFrame.initializedLocals.clone(); + final int location = bytecode.addBranchLocation(RuntimeBytecodeGen.BranchOp.BR); bytecode.patchLocation(falseJumpLocation, bytecode.location()); falseJumpLocation = location; elseBranch = true; @@ -85,8 +88,17 @@ void enterElse(ParserState state, RuntimeBytecodeGen bytecode) { @Override void exit(RuntimeBytecodeGen bytecode) { - if (!elseBranch && !Arrays.equals(paramTypes(), resultTypes())) { - throw WasmException.create(Failure.TYPE_MISMATCH, "Expected else branch. If with incompatible param and result types requires else branch."); + if (!elseBranch) { + if (resultTypes().length != paramTypes().length) { + throw WasmException.create(Failure.TYPE_MISMATCH, "Expected else branch. If with incompatible param and result types requires else branch."); + } + if (!isUnreachable()) { + for (int i = 0; i < resultTypes().length; i++) { + if (!getSymbolTable().matchesType(resultTypes()[i], paramTypes()[i])) { + throw WasmException.create(Failure.TYPE_MISMATCH, "Expected else branch. If with incompatible param and result types requires else branch."); + } + } + } } if (branchTargets.size() == 0 && exceptionHandlers.isEmpty()) { bytecode.patchLocation(falseJumpLocation, bytecode.location()); @@ -103,13 +115,8 @@ void exit(RuntimeBytecodeGen bytecode) { } @Override - void addBranch(RuntimeBytecodeGen bytecode) { - branchTargets.add(bytecode.addBranchLocation()); - } - - @Override - void addBranchIf(RuntimeBytecodeGen bytecode) { - branchTargets.add(bytecode.addBranchIfLocation()); + void addBranch(RuntimeBytecodeGen bytecode, RuntimeBytecodeGen.BranchOp branchOp) { + branchTargets.add(bytecode.addBranchLocation(branchOp)); } @Override diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/LoopFrame.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/LoopFrame.java index 71b69f9281db..0575e3560070 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/LoopFrame.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/LoopFrame.java @@ -45,19 +45,21 @@ import org.graalvm.wasm.exception.WasmException; import org.graalvm.wasm.parser.bytecode.RuntimeBytecodeGen; +import java.util.BitSet; + /** * Representation of a wasm loop during module validation. */ class LoopFrame extends ControlFrame { private final int labelLocation; - LoopFrame(byte[] paramTypes, byte[] resultTypes, int initialStackSize, boolean unreachable, int labelLocation) { - super(paramTypes, resultTypes, initialStackSize, unreachable); + LoopFrame(int[] paramTypes, int[] resultTypes, int initialStackSize, ControlFrame parentFrame, int labelLocation) { + super(paramTypes, resultTypes, parentFrame.getSymbolTable(), initialStackSize, (BitSet) parentFrame.initializedLocals.clone()); this.labelLocation = labelLocation; } @Override - byte[] labelTypes() { + int[] labelTypes() { return paramTypes(); } @@ -71,13 +73,8 @@ void exit(RuntimeBytecodeGen bytecode) { } @Override - void addBranch(RuntimeBytecodeGen bytecode) { - bytecode.addBranch(labelLocation); - } - - @Override - void addBranchIf(RuntimeBytecodeGen bytecode) { - bytecode.addBranchIf(labelLocation); + void addBranch(RuntimeBytecodeGen bytecode, RuntimeBytecodeGen.BranchOp branchOp) { + bytecode.addBranch(labelLocation, branchOp); } @Override diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ParserState.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ParserState.java index f2d4d3f5b307..63d783cf62c5 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ParserState.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ParserState.java @@ -46,9 +46,10 @@ import java.util.ArrayList; import org.graalvm.wasm.Assert; +import org.graalvm.wasm.SymbolTable; import org.graalvm.wasm.WasmType; import org.graalvm.wasm.api.Vector128; -import org.graalvm.wasm.collection.ByteArrayList; +import org.graalvm.wasm.collection.IntArrayList; import org.graalvm.wasm.constants.Bytecode; import org.graalvm.wasm.exception.Failure; import org.graalvm.wasm.exception.WasmException; @@ -59,22 +60,21 @@ * additional information used to generate parser nodes. */ public class ParserState { - private static final byte[] EMPTY_ARRAY = new byte[0]; - private static final byte ANY = 0; - - private final ByteArrayList valueStack; + private final IntArrayList valueStack; private final ControlStack controlStack; private final RuntimeBytecodeGen bytecode; private final ArrayList exceptionTables; + private final SymbolTable symbolTable; private int maxStackSize; private boolean usesMemoryZero; - public ParserState(RuntimeBytecodeGen bytecode) { - this.valueStack = new ByteArrayList(); + public ParserState(RuntimeBytecodeGen bytecode, SymbolTable symbolTable) { + this.valueStack = new IntArrayList(); this.controlStack = new ControlStack(); this.bytecode = bytecode; this.exceptionTables = new ArrayList<>(); + this.symbolTable = symbolTable; this.maxStackSize = 0; } @@ -85,13 +85,13 @@ public ParserState(RuntimeBytecodeGen bytecode) { * @param expectedValueType The expectedValueType used for error generation. * @return The top of the stack or -1. */ - private byte popInternal(byte expectedValueType) { + private int popInternal(int expectedValueType) { if (availableStackSize() == 0) { if (isCurrentStackUnreachable()) { - return WasmType.UNKNOWN_TYPE; + return WasmType.BOT; } else { - if (expectedValueType == ANY) { - throw ValidationErrors.createExpectedAnyOnEmptyStack(); + if (expectedValueType == WasmType.TOP) { + throw ValidationErrors.createExpectedTopOnEmptyStack(); } else { throw ValidationErrors.createExpectedTypeOnEmptyStack(expectedValueType); } @@ -101,7 +101,7 @@ private byte popInternal(byte expectedValueType) { } /** - * Pops the maximum available values form the current stack frame. If the number of values on + * Pops the maximum available values from the current stack frame. If the number of values on * the stack is smaller than the number of expectedValueTypes, only the remaining stack values * are returned. If the number of values on the stack is greater or equal to the number of * expectedValueTypes, the values equal to the number of expectedValueTypes is popped from the @@ -110,10 +110,10 @@ private byte popInternal(byte expectedValueType) { * @param expectedValueTypes Value types expected on the stack. * @return The maximum of available stack values smaller than the length of expectedValueTypes. */ - private byte[] popAvailableUnchecked(byte[] expectedValueTypes) { + private int[] popAvailableUnchecked(int[] expectedValueTypes) { int availableStackSize = availableStackSize(); int availableSize = Math.min(availableStackSize, expectedValueTypes.length); - byte[] popped = new byte[availableSize]; + int[] popped = new int[availableSize]; for (int i = availableSize - 1; i >= 0; i--) { popped[i] = popInternal(expectedValueTypes[i]); } @@ -125,12 +125,12 @@ private byte[] popAvailableUnchecked(byte[] expectedValueTypes) { * * @return The maximum of available stack values. */ - private byte[] popAvailableUnchecked() { + private int[] popAvailableUnchecked() { int availableStackSize = availableStackSize(); - byte[] popped = new byte[availableStackSize]; + int[] popped = new int[availableStackSize]; int j = 0; for (int i = availableStackSize - 1; i >= 0; i--) { - popped[j] = popInternal(ANY); + popped[j] = popInternal(WasmType.TOP); j++; } return popped; @@ -143,7 +143,7 @@ private byte[] popAvailableUnchecked() { * @param actualTypes The actual value types. * @return True if both are equivalent. */ - private boolean isTypeMismatch(byte[] expectedTypes, byte[] actualTypes) { + private boolean isTypeMismatch(int[] expectedTypes, int[] actualTypes) { if (expectedTypes.length != actualTypes.length) { return true; } @@ -151,7 +151,7 @@ private boolean isTypeMismatch(byte[] expectedTypes, byte[] actualTypes) { return false; } for (int i = 0; i < expectedTypes.length; i++) { - if (expectedTypes[i] != actualTypes[i]) { + if (!symbolTable.matchesType(expectedTypes[i], actualTypes[i])) { return true; } } @@ -163,8 +163,8 @@ private boolean isTypeMismatch(byte[] expectedTypes, byte[] actualTypes) { * * @param valueType The value type that should be added. */ - public void push(byte valueType) { - valueStack.push(valueType); + public void push(int valueType) { + valueStack.add(valueType); maxStackSize = Math.max(valueStack.size(), maxStackSize); } @@ -173,8 +173,8 @@ public void push(byte valueType) { * * @param valueTypes The value types that should be added. */ - public void pushAll(byte[] valueTypes) { - for (byte valueType : valueTypes) { + public void pushAll(int[] valueTypes) { + for (int valueType : valueTypes) { push(valueType); } } @@ -185,8 +185,8 @@ public void pushAll(byte[] valueTypes) { * @return The value type on top of the stack or -1. * @throws WasmException If the stack is empty. */ - public byte pop() { - return popInternal(ANY); + public int pop() { + return popInternal(WasmType.TOP); } /** @@ -196,9 +196,9 @@ public byte pop() { * @return The value type on top of the stack. * @throws WasmException If the stack is empty or the value types do not match. */ - public byte popChecked(byte expectedValueType) { - final byte actualValueType = popInternal(expectedValueType); - if (actualValueType != expectedValueType && actualValueType != WasmType.UNKNOWN_TYPE && expectedValueType != WasmType.UNKNOWN_TYPE) { + public int popChecked(int expectedValueType) { + final int actualValueType = popInternal(expectedValueType); + if (!symbolTable.matchesType(expectedValueType, actualValueType)) { throw ValidationErrors.createTypeMismatch(expectedValueType, actualValueType); } return actualValueType; @@ -208,18 +208,19 @@ public byte popChecked(byte expectedValueType) { * Pops the topmost value type from the stack and checks if it is a reference type. * * @throws WasmException If the stack is empty or the value type is not a reference type. + * @return The reference type on top of the stack. */ - public void popReferenceTypeChecked() { + public int popReferenceTypeChecked() { if (availableStackSize() != 0) { - final byte value = valueStack.popBack(); + final int value = valueStack.popBack(); if (WasmType.isReferenceType(value)) { - return; + return value; } // Push value back onto the stack and perform a checked pop to get the correct error // message valueStack.push(value); } - popChecked(WasmType.FUNCREF_TYPE); + return popChecked(WasmType.FUNCREF_TYPE); } /** @@ -230,8 +231,8 @@ public void popReferenceTypeChecked() { * @return The value types on top of the stack. * @throws WasmException If the stack is empty or the value types do not match. */ - public byte[] popAll(byte[] expectedValueTypes) { - byte[] popped = new byte[expectedValueTypes.length]; + public int[] popAll(int[] expectedValueTypes) { + int[] popped = new int[expectedValueTypes.length]; for (int i = expectedValueTypes.length - 1; i >= 0; i--) { popped[i] = popChecked(expectedValueTypes[i]); } @@ -264,8 +265,9 @@ private void unwindStack(int size) { } } - public void enterFunction(byte[] resultTypes) { - enterBlock(EMPTY_ARRAY, resultTypes); + public void enterFunction(int[] paramTypes, int[] resultTypes, int[] locals) { + ControlFrame frame = BlockFrame.createFunctionFrame(paramTypes, resultTypes, locals, symbolTable); + controlStack.push(frame); } /** @@ -275,8 +277,8 @@ public void enterFunction(byte[] resultTypes) { * @param paramTypes The param types of the block that was entered. * @param resultTypes The result types of the block that was entered. */ - public void enterBlock(byte[] paramTypes, byte[] resultTypes) { - ControlFrame frame = new BlockFrame(paramTypes, resultTypes, valueStack.size(), false); + public void enterBlock(int[] paramTypes, int[] resultTypes) { + ControlFrame frame = new BlockFrame(paramTypes, resultTypes, valueStack.size(), controlStack.peek()); controlStack.push(frame); pushAll(paramTypes); } @@ -288,9 +290,9 @@ public void enterBlock(byte[] paramTypes, byte[] resultTypes) { * @param paramTypes The param types of the loop that was entered. * @param resultTypes The result types of the loop that was entered. */ - public void enterLoop(byte[] paramTypes, byte[] resultTypes) { + public void enterLoop(int[] paramTypes, int[] resultTypes) { final int label = bytecode.addLoopLabel(paramTypes.length, valueStack.size(), WasmType.getCommonValueType(resultTypes)); - ControlFrame frame = new LoopFrame(paramTypes, resultTypes, valueStack.size(), false, label); + ControlFrame frame = new LoopFrame(paramTypes, resultTypes, valueStack.size(), controlStack.peek(), label); controlStack.push(frame); pushAll(paramTypes); } @@ -302,9 +304,9 @@ public void enterLoop(byte[] paramTypes, byte[] resultTypes) { * @param paramTypes The param types of the if and else branch that was entered. * @param resultTypes The result type of the if and else branch that was entered. */ - public void enterIf(byte[] paramTypes, byte[] resultTypes) { + public void enterIf(int[] paramTypes, int[] resultTypes) { final int fixupLocation = bytecode.addIfLocation(); - ControlFrame frame = new IfFrame(paramTypes, resultTypes, valueStack.size(), false, fixupLocation); + ControlFrame frame = new IfFrame(paramTypes, resultTypes, valueStack.size(), controlStack.peek(), fixupLocation); controlStack.push(frame); pushAll(paramTypes); } @@ -326,8 +328,8 @@ public void enterElse() { * @param resultTypes The result types of the try table that was entered. * @param handlers The exception handlers of the try table that was entered. */ - public void enterTryTable(byte[] paramTypes, byte[] resultTypes, ExceptionHandler[] handlers) { - final TryTableFrame frame = new TryTableFrame(paramTypes, resultTypes, valueStack.size(), false, bytecode.location(), handlers); + public void enterTryTable(int[] paramTypes, int[] resultTypes, ExceptionHandler[] handlers) { + final TryTableFrame frame = new TryTableFrame(paramTypes, resultTypes, valueStack.size(), controlStack.peek(), bytecode.location(), handlers); controlStack.push(frame); exceptionTables.add(frame.table()); @@ -347,7 +349,7 @@ public ExceptionHandler enterCatchClause(int opcode, int tag, int label) { checkLabelExists(label); final ControlFrame labelFrame = getFrame(label); // we reuse the block frame, instead of introducing a new catch frame. - final ControlFrame frame = new BlockFrame(WasmType.VOID_TYPE_ARRAY, labelFrame.labelTypes(), labelFrame.initialStackSize(), false); + final ControlFrame frame = new BlockFrame(WasmType.VOID_TYPE_ARRAY, labelFrame.labelTypes(), labelFrame.initialStackSize(), controlStack.peek()); controlStack.push(frame); final ExceptionHandler e = new ExceptionHandler(opcode, tag); labelFrame.addExceptionHandler(e); @@ -396,10 +398,10 @@ public void addSelectInstruction(int instruction) { public void addConditionalBranch(int branchLabel) { checkLabelExists(branchLabel); ControlFrame frame = getFrame(branchLabel); - final byte[] labelTypes = frame.labelTypes(); + final int[] labelTypes = frame.labelTypes(); popAll(labelTypes); pushAll(labelTypes); - frame.addBranchIf(bytecode); + frame.addBranch(bytecode, RuntimeBytecodeGen.BranchOp.BR_IF); } /** @@ -411,9 +413,37 @@ public void addConditionalBranch(int branchLabel) { public void addUnconditionalBranch(int branchLabel) { checkLabelExists(branchLabel); ControlFrame frame = getFrame(branchLabel); - final byte[] labelTypes = frame.labelTypes(); + final int[] labelTypes = frame.labelTypes(); + popAll(labelTypes); + frame.addBranch(bytecode, RuntimeBytecodeGen.BranchOp.BR); + } + + public void addBranchOnNull(int branchLabel) { + checkLabelExists(branchLabel); + ControlFrame frame = getFrame(branchLabel); + final int[] labelTypes = frame.labelTypes(); popAll(labelTypes); - frame.addBranch(bytecode); + pushAll(labelTypes); + frame.addBranch(bytecode, RuntimeBytecodeGen.BranchOp.BR_ON_NULL); + } + + public void addBranchOnNonNull(int branchLabel, int referenceType) { + checkLabelExists(branchLabel); + ControlFrame frame = getFrame(branchLabel); + final int[] labelTypes = frame.labelTypes(); + if (labelTypes.length < 1) { + throw ValidationErrors.createLabelTypesMismatch(labelTypes, new int[]{referenceType}); + } + if (!symbolTable.matchesType(labelTypes[labelTypes.length - 1], referenceType)) { + throw ValidationErrors.createTypeMismatch(labelTypes[labelTypes.length - 1], referenceType); + } + for (int i = labelTypes.length - 2; i >= 0; i--) { + popChecked(labelTypes[i]); + } + for (int i = 0; i < labelTypes.length - 1; i++) { + push(labelTypes[i]); + } + frame.addBranch(bytecode, RuntimeBytecodeGen.BranchOp.BR_ON_NON_NULL); } /** @@ -427,13 +457,20 @@ public void addBranchTable(int[] branchLabels) { int branchLabel = branchLabels[branchLabels.length - 1]; checkLabelExists(branchLabel); ControlFrame frame = getFrame(branchLabel); - byte[] branchLabelReturnTypes = frame.labelTypes(); + int[] branchLabelReturnTypes = frame.labelTypes(); + int arity = branchLabelReturnTypes.length; for (int otherBranchLabel : branchLabels) { checkLabelExists(otherBranchLabel); frame = getFrame(otherBranchLabel); - byte[] otherBranchLabelReturnTypes = frame.labelTypes(); - checkLabelTypes(branchLabelReturnTypes, otherBranchLabelReturnTypes); - pushAll(popAll(otherBranchLabelReturnTypes)); + int[] otherBranchLabelReturnTypes = frame.labelTypes(); + if (otherBranchLabelReturnTypes.length != arity) { + throw ValidationErrors.createLabelTypesMismatch(branchLabelReturnTypes, otherBranchLabelReturnTypes); + } + try { + pushAll(popAll(otherBranchLabelReturnTypes)); + } catch (WasmException e) { + throw ValidationErrors.createLabelTypesMismatch(branchLabelReturnTypes, otherBranchLabelReturnTypes); + } frame.addBranchTableItem(bytecode); } popAll(branchLabelReturnTypes); @@ -455,18 +492,34 @@ public void addReturn(boolean multiValue) { } /** - * Adds the index of an indirect call node to the extra data array. + * Adds a reference call instruction to the bytecode, along with its immediate argument and the + * call node index. + * + * @param nodeIndex The index of the call node associated with this call instruction. + * @param typeIndex The index of the defined function type. + */ + public void addRefCall(int nodeIndex, int typeIndex) { + bytecode.addRefCall(nodeIndex, typeIndex); + } + + /** + * Adds an indirect call instruction to the bytecode, along with its immediate arguments and the + * call node index. * - * @param nodeIndex The index of the indirect call. + * @param nodeIndex The index of the call node associated with this call instruction. + * @param typeIndex The index of the defined function type. + * @param tableIndex The index of the table in which the function will be looked up. */ public void addIndirectCall(int nodeIndex, int typeIndex, int tableIndex) { bytecode.addIndirectCall(nodeIndex, typeIndex, tableIndex); } /** - * Adds the index of a direct call node to the extra data array. + * Adds a direct call instruction to the bytecode, along with its immediate argument and the + * call node index. * - * @param nodeIndex The index of the direct call. + * @param nodeIndex The index of the call node associated with this call instruction. + * @param functionIndex The index of the defined function. */ public void addCall(int nodeIndex, int functionIndex) { bytecode.addCall(nodeIndex, functionIndex); @@ -633,10 +686,10 @@ public void addVectorLaneInstruction(int instruction, byte laneIndex) { * * @return The types of the return values of the current frame. */ - public byte[] exit(boolean multiValue) { + public int[] exit(boolean multiValue) { Assert.assertTrue(!controlStack.isEmpty(), Failure.UNEXPECTED_END_OF_BLOCK); ControlFrame frame = controlStack.peek(); - byte[] resultTypes = frame.resultTypes(); + int[] resultTypes = frame.resultTypes(); frame.exit(bytecode); checkStackAfterFrameExit(frame, resultTypes); @@ -653,9 +706,9 @@ public byte[] exit(boolean multiValue) { * @param frame The frame that is exited. * @param resultTypes The expected return types of the frame. */ - void checkStackAfterFrameExit(ControlFrame frame, byte[] resultTypes) { + void checkStackAfterFrameExit(ControlFrame frame, int[] resultTypes) { if (availableStackSize() > resultTypes.length) { - byte[] actualTypes = popAvailableUnchecked(); + int[] actualTypes = popAvailableUnchecked(); if (isTypeMismatch(resultTypes, actualTypes)) { throw ValidationErrors.createResultTypesMismatch(resultTypes, actualTypes); } @@ -678,6 +731,14 @@ public ControlFrame getRootBlock() { return controlStack.getFirst(); } + public boolean isLocalInitialized(int localIndex) { + return controlStack.peek().isLocalInitialized(localIndex); + } + + public void initializeLocal(int localIndex) { + controlStack.peek().initializeLocal(localIndex); + } + /** * Checks if the return value types of the given control frame match the remaining value types * on the stack. @@ -687,11 +748,11 @@ public ControlFrame getRootBlock() { * stack. */ private void checkResultTypes(ControlFrame frame) { - byte[] resultTypes = frame.resultTypes(); + int[] resultTypes = frame.resultTypes(); if (isCurrentStackUnreachable()) { popAll(resultTypes); } else { - byte[] actualTypes = popAvailableUnchecked(resultTypes); + int[] actualTypes = popAvailableUnchecked(resultTypes); if (isTypeMismatch(resultTypes, actualTypes)) { throw ValidationErrors.createResultTypesMismatch(resultTypes, actualTypes); } @@ -705,11 +766,11 @@ private void checkResultTypes(ControlFrame frame) { * @throws WasmException If the parameter value types and the vale types on the stack do not * match. */ - public void checkParamTypes(byte[] paramTypes) { + public void checkParamTypes(int[] paramTypes) { if (isCurrentStackUnreachable()) { popAll(paramTypes); } else { - byte[] actualTypes = popAvailableUnchecked(paramTypes); + int[] actualTypes = popAvailableUnchecked(paramTypes); if (isTypeMismatch(paramTypes, actualTypes)) { throw ValidationErrors.createParamTypesMismatch(paramTypes, actualTypes); } @@ -728,32 +789,6 @@ public void checkLabelExists(int label) { } } - /** - * Checks if the value types of two different labels match. - * - * @param expectedTypes The expected value types. - * @param actualTypes The value types that should be checked. - * @throws WasmException If the provided sets of value types do not match. - */ - public void checkLabelTypes(byte[] expectedTypes, byte[] actualTypes) { - if (isTypeMismatch(expectedTypes, actualTypes)) { - throw ValidationErrors.createLabelTypesMismatch(expectedTypes, actualTypes); - } - } - - /** - * Checks if the given function type is within range. - * - * @param typeIndex The function type. - * @param max The number of available function types. - * @throws WasmException If the given function type is greater or equal to the given maximum. - */ - public void checkFunctionTypeExists(int typeIndex, int max) { - if (compareUnsigned(typeIndex, max) >= 0) { - throw ValidationErrors.createMissingFunctionType(typeIndex, max - 1); - } - } - /** * Sets the current control stack unreachable. */ diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/TryTableFrame.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/TryTableFrame.java index 271b33f0c4be..ddfe213a700a 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/TryTableFrame.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/TryTableFrame.java @@ -49,8 +49,8 @@ public class TryTableFrame extends BlockFrame { private final ExceptionTable table; - TryTableFrame(byte[] paramTypes, byte[] resultTypes, int initialStackSize, boolean unreachable, int startOffset, ExceptionHandler[] handlers) { - super(paramTypes, resultTypes, initialStackSize, unreachable); + TryTableFrame(int[] paramTypes, int[] resultTypes, int initialStackSize, ControlFrame parentFrame, int startOffset, ExceptionHandler[] handlers) { + super(paramTypes, resultTypes, initialStackSize, parentFrame); this.table = new ExceptionTable(startOffset, handlers); } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ValidationErrors.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ValidationErrors.java index 3e55f0dfec8c..26f9e8d926d5 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ValidationErrors.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ValidationErrors.java @@ -50,25 +50,10 @@ import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; public class ValidationErrors { - private static String getValueTypeString(byte valueType) { - return switch (valueType) { - case WasmType.VOID_TYPE -> ""; - case WasmType.I32_TYPE -> "i32"; - case WasmType.I64_TYPE -> "i64"; - case WasmType.F32_TYPE -> "f32"; - case WasmType.F64_TYPE -> "f64"; - case WasmType.V128_TYPE -> "v128"; - case WasmType.FUNCREF_TYPE -> "funcref"; - case WasmType.EXTERNREF_TYPE -> "externref"; - case WasmType.EXNREF_TYPE -> "exnref"; - default -> "unknown"; - }; - } - - private static String getValueTypesString(byte[] valueTypes) { + private static String getValueTypesString(int[] valueTypes) { StringJoiner stringJoiner = new StringJoiner(","); - for (byte valueType : valueTypes) { - stringJoiner.add(getValueTypeString(valueType)); + for (int valueType : valueTypes) { + stringJoiner.add(WasmType.toString(valueType)); } return stringJoiner.toString(); } @@ -78,28 +63,28 @@ private static WasmException create(String message, Object expected, Object actu } @TruffleBoundary - public static WasmException createTypeMismatch(byte expectedType, byte actualType) { - String expectedTypeString = getValueTypeString(expectedType); - String actualTypeString = getValueTypeString(actualType); + public static WasmException createTypeMismatch(int expectedType, int actualType) { + String expectedTypeString = WasmType.toString(expectedType); + String actualTypeString = WasmType.toString(actualType); return create("Expected type [%s], but got [%s].", expectedTypeString, actualTypeString); } @TruffleBoundary - public static WasmException createResultTypesMismatch(byte[] expectedTypes, byte[] actualTypes) { + public static WasmException createResultTypesMismatch(int[] expectedTypes, int[] actualTypes) { String expectedTypesString = getValueTypesString(expectedTypes); String actualTypesString = getValueTypesString(actualTypes); return create("Expected result types [%s], but got [%s].", expectedTypesString, actualTypesString); } @TruffleBoundary - public static WasmException createLabelTypesMismatch(byte[] expectedTypes, byte[] actualTypes) { + public static WasmException createLabelTypesMismatch(int[] expectedTypes, int[] actualTypes) { String expectedTypesString = getValueTypesString(expectedTypes); String actualTypesString = getValueTypesString(actualTypes); return create("Inconsistent label types. Expected [%s], but got [%s].", expectedTypesString, actualTypesString); } @TruffleBoundary - public static WasmException createParamTypesMismatch(byte[] expectedTypes, byte[] actualTypes) { + public static WasmException createParamTypesMismatch(int[] expectedTypes, int[] actualTypes) { String expectedTypesString = getValueTypesString(expectedTypes); String actualTypesString = getValueTypesString(actualTypes); return create("Expected param types [%s], but got [%s].", expectedTypesString, actualTypesString); @@ -110,19 +95,24 @@ public static WasmException createMissingLabel(int expected, int max) { return WasmException.format(Failure.UNKNOWN_LABEL, "Unknown branch label %d (max %d).", expected, max); } + @TruffleBoundary + public static WasmException createMissingFunctionType(int expected) { + return WasmException.format(Failure.UNKNOWN_TYPE, "Function type variable %d out of range.", expected); + } + @TruffleBoundary public static WasmException createMissingFunctionType(int expected, int max) { return WasmException.format(Failure.UNKNOWN_TYPE, "Function type variable %d out of range. (max %d)", expected, max); } @TruffleBoundary - public static WasmException createExpectedAnyOnEmptyStack() { - return WasmException.create(Failure.TYPE_MISMATCH, "Expected type [any], but got []."); + public static WasmException createExpectedTopOnEmptyStack() { + return WasmException.create(Failure.TYPE_MISMATCH, "Expected type [top], but got []."); } @TruffleBoundary - public static WasmException createExpectedTypeOnEmptyStack(byte expectedType) { - String expectedTypeString = getValueTypeString(expectedType); + public static WasmException createExpectedTypeOnEmptyStack(int expectedType) { + String expectedTypeString = WasmType.toString(expectedType); return create("Expected type [%s], but got [].", expectedTypeString, ""); } } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/predefined/BuiltinModule.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/predefined/BuiltinModule.java index d5390e6a9521..b9afdc82bd81 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/predefined/BuiltinModule.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/predefined/BuiltinModule.java @@ -97,7 +97,7 @@ public WasmInstance createInstance(WasmLanguage language, WasmStore store, Strin return instance; } - protected WasmFunction defineFunction(WasmContext context, WasmModule module, String name, byte[] paramTypes, byte[] retTypes, WasmRootNode rootNode) { + protected WasmFunction defineFunction(WasmContext context, WasmModule module, String name, int[] paramTypes, int[] retTypes, WasmRootNode rootNode) { // Must instantiate RootNode in the right language / sharing layer. assert context.language() == rootNode.getLanguage(WasmLanguage.class); // We could check if the same function type had already been allocated, @@ -109,27 +109,23 @@ protected WasmFunction defineFunction(WasmContext context, WasmModule module, St return function; } - protected int defineGlobal(WasmModule module, String name, byte valueType, byte mutability, Object value) { + protected int defineGlobal(WasmModule module, String name, int valueType, byte mutability, Object value) { int index = module.symbolTable().numGlobals(); module.symbolTable().declareExportedGlobalWithValue(name, index, valueType, mutability, value); return index; } - protected int defineTable(WasmContext context, WasmModule module, String tableName, int initSize, int maxSize, byte type) { + protected int defineTable(WasmContext context, WasmModule module, String tableName, int initSize, int maxSize, int type) { final boolean referenceTypes = context.getContextOptions().supportBulkMemoryAndRefTypes(); - switch (type) { - case WasmType.FUNCREF_TYPE: - break; - case WasmType.EXTERNREF_TYPE: - if (!referenceTypes) { - throw WasmException.create(Failure.UNSPECIFIED_MALFORMED, "Only function types are currently supported in tables."); - } - break; - default: - throw WasmException.create(Failure.MALFORMED_REFERENCE_TYPE, "Only reference types supported in tables."); + if (!WasmType.isReferenceType(type)) { + throw WasmException.create(Failure.MALFORMED_REFERENCE_TYPE, "Only reference types supported in tables."); + } else if (!referenceTypes && type != WasmType.FUNCREF_TYPE) { + throw WasmException.create(Failure.UNSPECIFIED_MALFORMED, "Only function types are currently supported in tables."); + } else if (!WasmType.isNullable(type)) { + throw WasmException.create(Failure.TYPE_MISMATCH, "Tables of built-in modules must be nullable."); } int index = module.symbolTable().tableCount(); - module.symbolTable().allocateTable(index, initSize, maxSize, type, referenceTypes); + module.symbolTable().declareTable(index, initSize, maxSize, type, null, null, referenceTypes); module.symbolTable().exportTable(index, tableName); return index; } @@ -143,7 +139,7 @@ protected void defineMemory(WasmContext context, WasmModule module, String memor module.symbolTable().exportMemory(index, memoryName); } - protected void importFunction(WasmContext context, WasmModule module, String importModuleName, String importFunctionName, byte[] paramTypes, byte[] retTypes, String exportName) { + protected void importFunction(WasmContext context, WasmModule module, String importModuleName, String importFunctionName, int[] paramTypes, int[] retTypes, String exportName) { final int typeIdx = module.symbolTable().allocateFunctionType(paramTypes, retTypes, context.getContextOptions().supportMultiValue()); final WasmFunction function = module.symbolTable().importFunction(importModuleName, importFunctionName, typeIdx); module.symbolTable().exportFunction(function.index(), exportName); @@ -155,7 +151,7 @@ protected void importMemory(WasmContext context, WasmModule module, String impor module.symbolTable().importMemory(importModuleName, memoryName, index, initSize, maxSize, is64Bit, isShared, multiMemory); } - protected static byte[] types(byte... args) { + protected static int[] types(int... args) { return args; } }