From 080d3ea7648edef252babd3c9a3dc509d30aef64 Mon Sep 17 00:00:00 2001 From: Jirka Marsik Date: Tue, 23 Sep 2025 10:18:15 +0200 Subject: [PATCH 01/40] Introduced typed function references to GraalWasm type system --- .../wasm/test/AbstractBinarySuite.java | 99 +++-- .../org/graalvm/wasm/test/WasmJsApiSuite.java | 3 +- .../test/suites/bytecode/BytecodeSuite.java | 48 +- .../debugging/DebugValidationSuite.java | 2 +- .../ReferenceTypesValidationSuite.java | 47 +- .../suites/validation/ValidationSuite.java | 2 +- .../src/org/graalvm/wasm/BinaryParser.java | 419 ++++++++++++------ .../org/graalvm/wasm/BinaryStreamParser.java | 83 +--- .../src/org/graalvm/wasm/GlobalRegistry.java | 9 +- .../src/org/graalvm/wasm/Linker.java | 65 +-- .../src/org/graalvm/wasm/ModuleLimits.java | 4 +- .../src/org/graalvm/wasm/SymbolTable.java | 299 +++++++++---- .../src/org/graalvm/wasm/WasmCodeEntry.java | 10 +- .../src/org/graalvm/wasm/WasmFunction.java | 14 +- .../org/graalvm/wasm/WasmInstantiator.java | 21 +- .../src/org/graalvm/wasm/WasmLanguage.java | 18 - .../src/org/graalvm/wasm/WasmTable.java | 13 +- .../src/org/graalvm/wasm/WasmType.java | 153 ++++--- .../wasm/api/ExecuteHostFunctionNode.java | 29 +- .../src/org/graalvm/wasm/api/FuncType.java | 18 +- .../wasm/api/InteropCallAdapterNode.java | 66 +-- .../src/org/graalvm/wasm/api/TableKind.java | 14 +- .../src/org/graalvm/wasm/api/ValueType.java | 15 +- .../org/graalvm/wasm/api/Vector128Shape.java | 6 +- .../src/org/graalvm/wasm/api/WebAssembly.java | 10 +- .../wasm/collection/ByteArrayList.java | 56 +-- .../graalvm/wasm/collection/IntArrayList.java | 13 + .../wasm/constants/BytecodeBitEncoding.java | 7 +- .../graalvm/wasm/nodes/WasmFunctionNode.java | 109 ++--- .../wasm/nodes/WasmFunctionRootNode.java | 124 ++---- .../wasm/parser/bytecode/BytecodeParser.java | 38 +- .../parser/bytecode/RuntimeBytecodeGen.java | 46 +- .../org/graalvm/wasm/parser/ir/CodeEntry.java | 10 +- .../wasm/parser/validation/BlockFrame.java | 4 +- .../wasm/parser/validation/ControlFrame.java | 12 +- .../wasm/parser/validation/IfFrame.java | 4 +- .../wasm/parser/validation/LoopFrame.java | 4 +- .../wasm/parser/validation/ParserState.java | 78 ++-- .../wasm/parser/validation/TryTableFrame.java | 2 +- .../parser/validation/ValidationErrors.java | 37 +- .../wasm/predefined/BuiltinModule.java | 24 +- 41 files changed, 1117 insertions(+), 918 deletions(-) diff --git a/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/AbstractBinarySuite.java b/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/AbstractBinarySuite.java index bc065856b093..358b3f203c3a 100644 --- a/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/AbstractBinarySuite.java +++ b/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/AbstractBinarySuite.java @@ -53,9 +53,10 @@ import org.graalvm.polyglot.io.ByteSequence; import org.graalvm.wasm.WasmLanguage; import org.graalvm.wasm.collection.ByteArrayList; +import org.graalvm.wasm.collection.IntArrayList; public abstract class AbstractBinarySuite { - protected static final byte[] EMPTY_BYTES = {}; + protected static final int[] EMPTY_INTS = {}; protected static void runRuntimeTest(byte[] binary, Consumer options, Consumer testCase) throws IOException { final Context.Builder contextBuilder = Context.newBuilder(WasmLanguage.ID); @@ -100,10 +101,10 @@ private static byte getByte(String hexString) { private static final class BinaryTypes { - private final List paramEntries = new ArrayList<>(); - private final List resultEntries = new ArrayList<>(); + private final List paramEntries = new ArrayList<>(); + private final List resultEntries = new ArrayList<>(); - private void add(byte[] params, byte[] results) { + private void add(int[] params, int[] results) { paramEntries.add(params); resultEntries.add(results); } @@ -112,18 +113,18 @@ private byte[] generateTypeSection() { ByteArrayList b = new ByteArrayList(); b.add(getByte("01")); b.add((byte) 0); // length is patched in the end - b.add((byte) paramEntries.size()); + b.addUnsignedInt32(paramEntries.size()); for (int i = 0; i < paramEntries.size(); i++) { b.add(getByte("60")); - byte[] params = paramEntries.get(i); - byte[] results = resultEntries.get(i); - b.add((byte) params.length); - for (byte param : params) { - b.add(param); + int[] params = paramEntries.get(i); + int[] results = resultEntries.get(i); + b.addUnsignedInt32(params.length); + for (int param : params) { + b.addSignedInt32(param); } - b.add((byte) results.length); - for (byte result : results) { - b.add(result); + b.addUnsignedInt32(results.length); + for (int result : results) { + b.addSignedInt32(result); } } b.set(1, (byte) (b.size() - 2)); @@ -132,9 +133,9 @@ private byte[] generateTypeSection() { } private static final class BinaryTables { - private final ByteArrayList tables = new ByteArrayList(); + private final IntArrayList tables = new IntArrayList(); - private void add(byte initSize, byte maxSize, byte elemType) { + private void add(int initSize, int maxSize, int elemType) { tables.add(initSize); tables.add(maxSize); tables.add(elemType); @@ -147,10 +148,10 @@ private byte[] generateTableSection() { final int tableCount = tables.size() / 3; b.add((byte) tableCount); for (int i = 0; i < tables.size(); i += 3) { - b.add(tables.get(i + 2)); + b.addSignedInt32(tables.get(i + 2)); b.add(getByte("01")); - b.add(tables.get(i)); - b.add(tables.get(i + 1)); + b.addUnsignedInt32(tables.get(i)); + b.addUnsignedInt32(tables.get(i + 1)); } b.set(1, (byte) (b.size() - 2)); return b.toArray(); @@ -158,9 +159,9 @@ private byte[] generateTableSection() { } private static final class BinaryMemories { - private final ByteArrayList memories = new ByteArrayList(); + private final IntArrayList memories = new IntArrayList(); - private void add(byte initSize, byte maxSize) { + private void add(int initSize, int maxSize) { memories.add(initSize); memories.add(maxSize); } @@ -173,8 +174,8 @@ private byte[] generateMemorySection() { b.add((byte) memoryCount); for (int i = 0; i < memories.size(); i += 2) { b.add(getByte("01")); - b.add(memories.get(i)); - b.add(memories.get(i + 1)); + b.addUnsignedInt32(memories.get(i)); + b.addUnsignedInt32(memories.get(i + 1)); } b.set(1, (byte) (b.size() - 2)); return b.toArray(); @@ -182,11 +183,11 @@ private byte[] generateMemorySection() { } private static final class BinaryFunctions { - private final ByteArrayList types = new ByteArrayList(); - private final List localEntries = new ArrayList<>(); + private final IntArrayList types = new IntArrayList(); + private final List localEntries = new ArrayList<>(); private final List codeEntries = new ArrayList<>(); - private void add(byte typeIndex, byte[] locals, byte[] code) { + private void add(int typeIndex, int[] locals, byte[] code) { types.add(typeIndex); localEntries.add(locals); codeEntries.add(code); @@ -197,9 +198,9 @@ private byte[] generateFunctionSection() { b.add(getByte("03")); b.add((byte) 0); // length is patched at the end final int functionCount = types.size(); - b.add((byte) functionCount); + b.addUnsignedInt32(functionCount); for (int i = 0; i < functionCount; i++) { - b.add(types.get(i)); + b.addUnsignedInt32(types.get(i)); } b.set(1, (byte) (b.size() - 2)); return b.toArray(); @@ -210,15 +211,15 @@ private byte[] generateCodeSection() { b.add(getByte("0A")); b.add((byte) 0); // length is patched at the end final int functionCount = types.size(); - b.add((byte) functionCount); + b.addUnsignedInt32(functionCount); for (int i = 0; i < functionCount; i++) { - byte[] locals = localEntries.get(i); + int[] locals = localEntries.get(i); byte[] code = codeEntries.get(i); int length = 1 + locals.length + code.length; - b.add((byte) length); - b.add((byte) locals.length); - for (byte l : locals) { - b.add(l); + b.addUnsignedInt32(length); + b.addUnsignedInt32(locals.length); + for (int l : locals) { + b.addSignedInt32(l); } for (byte op : code) { b.add(op); @@ -231,10 +232,10 @@ private byte[] generateCodeSection() { private static final class BinaryExports { private final ByteArrayList types = new ByteArrayList(); - private final ByteArrayList indices = new ByteArrayList(); + private final IntArrayList indices = new IntArrayList(); private final List names = new ArrayList<>(); - private void addFunctionExport(byte functionIndex, String name) { + private void addFunctionExport(int functionIndex, String name) { types.add(getByte("00")); indices.add(functionIndex); names.add(name.getBytes(StandardCharsets.UTF_8)); @@ -244,15 +245,15 @@ private byte[] generateExportSection() { ByteArrayList b = new ByteArrayList(); b.add(getByte("07")); b.add((byte) 0); // length is patched at the end - b.add((byte) types.size()); + b.addUnsignedInt32(types.size()); for (int i = 0; i < types.size(); i++) { final byte[] name = names.get(i); - b.add((byte) name.length); + b.addUnsignedInt32(name.length); for (byte value : name) { b.add(value); } b.add(types.get(i)); - b.add(indices.get(i)); + b.addUnsignedInt32(indices.get(i)); } b.set(1, (byte) (b.size() - 2)); return b.toArray(); @@ -313,10 +314,10 @@ private byte[] generateDataSection() { private static final class BinaryGlobals { private final ByteArrayList mutabilities = new ByteArrayList(); - private final ByteArrayList valueTypes = new ByteArrayList(); + private final IntArrayList valueTypes = new IntArrayList(); private final List expressions = new ArrayList<>(); - private void add(byte mutability, byte valueType, byte[] expression) { + private void add(byte mutability, int valueType, byte[] expression) { mutabilities.add(mutability); valueTypes.add(valueType); expressions.add(expression); @@ -328,7 +329,7 @@ private byte[] generateGlobalSection() { b.add((byte) 0); // length is patched at the end b.add((byte) mutabilities.size()); for (int i = 0; i < mutabilities.size(); i++) { - b.add(valueTypes.get(i)); + b.addSignedInt32(valueTypes.get(i)); b.add(mutabilities.get(i)); for (byte e : expressions.get(i)) { b.add(e); @@ -378,8 +379,8 @@ private byte[] generateCustomSections() { final byte[] name = names.get(i); final byte[] section = sections.get(i); final int size = 1 + name.length + section.length; - b.add((byte) size); // length is patched at the end - b.add((byte) name.length); + b.addUnsignedInt32(size); + b.addUnsignedInt32(name.length); b.addRange(name, 0, name.length); b.addRange(section, 0, section.length); } @@ -400,27 +401,27 @@ protected static class BinaryBuilder { private final BinaryCustomSections binaryCustomSections = new BinaryCustomSections(); - public BinaryBuilder addType(byte[] params, byte[] results) { + public BinaryBuilder addType(int[] params, int[] results) { binaryTypes.add(params, results); return this; } - public BinaryBuilder addTable(byte initSize, byte maxSize, byte elemType) { + public BinaryBuilder addTable(int initSize, int maxSize, int elemType) { binaryTables.add(initSize, maxSize, elemType); return this; } - public BinaryBuilder addMemory(byte initSize, byte maxSize) { + public BinaryBuilder addMemory(int initSize, int maxSize) { binaryMemories.add(initSize, maxSize); return this; } - public BinaryBuilder addFunction(byte typeIndex, byte[] locals, String hexCode) { + public BinaryBuilder addFunction(int typeIndex, int[] locals, String hexCode) { binaryFunctions.add(typeIndex, locals, WasmTestUtils.hexStringToByteArray(hexCode)); return this; } - public BinaryBuilder addFunctionExport(byte functionIndex, String name) { + public BinaryBuilder addFunctionExport(int functionIndex, String name) { binaryExports.addFunctionExport(functionIndex, name); return this; } @@ -435,7 +436,7 @@ public BinaryBuilder addData(String hexCode) { return this; } - public BinaryBuilder addGlobal(byte mutability, byte valueType, String hexCode) { + public BinaryBuilder addGlobal(byte mutability, int valueType, String hexCode) { binaryGlobals.add(mutability, valueType, WasmTestUtils.hexStringToByteArray(hexCode)); return this; } diff --git a/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/WasmJsApiSuite.java b/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/WasmJsApiSuite.java index 3d8e933eccd7..b2a512cdbce0 100644 --- a/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/WasmJsApiSuite.java +++ b/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/WasmJsApiSuite.java @@ -102,7 +102,7 @@ public class WasmJsApiSuite { private static final String REF_TYPES_OPTION = "wasm.BulkMemoryAndRefTypes"; - private static WasmFunctionInstance createWasmFunctionInstance(WasmContext context, byte[] paramTypes, byte[] resultTypes, RootNode functionRootNode) { + private static WasmFunctionInstance createWasmFunctionInstance(WasmContext context, int[] paramTypes, int[] resultTypes, RootNode functionRootNode) { WasmModule module = WasmModule.createBuiltin("dummyModule"); module.allocateFunctionType(paramTypes, resultTypes, context.getContextOptions().supportMultiValue()); WasmFunction func = module.declareFunction(0); @@ -111,7 +111,6 @@ private static WasmFunctionInstance createWasmFunctionInstance(WasmContext conte // Perform normal linking steps, incl. assignTypeEquivalenceClasses(). // Functions need to have type equivalence classes assigned for indirect calls. moduleInstance.store().linker().tryLink(moduleInstance); - assert func.typeEquivalenceClass() >= 0 : "type equivalence class must be assigned"; return new WasmFunctionInstance(moduleInstance, func, functionRootNode.getCallTarget()); } diff --git a/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/bytecode/BytecodeSuite.java b/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/bytecode/BytecodeSuite.java index e19969a153fa..6a062a4ca2af 100644 --- a/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/bytecode/BytecodeSuite.java +++ b/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/bytecode/BytecodeSuite.java @@ -644,47 +644,47 @@ public void testDataRuntimeHeaderMinI32Length() { @Test public void testElemHeaderMin() { - test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 0, null, -1), new byte[]{0x40, 0x10, 0x00}); + test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 0, null, -1), new byte[]{0x40, 0x10, WasmType.FUNCREF_TYPE, 0x00}); } @Test public void testElemHeaderMinU8Count() { - test(b -> b.addElemHeader(SegmentMode.ACTIVE, 1, WasmType.FUNCREF_TYPE, 0, null, -1), new byte[]{0x40, 0x10, 0x01}); + test(b -> b.addElemHeader(SegmentMode.ACTIVE, 1, WasmType.FUNCREF_TYPE, 0, null, -1), new byte[]{0x40, 0x10, WasmType.FUNCREF_TYPE, 0x01}); } @Test public void testElemHeaderMaxU8Count() { - test(b -> b.addElemHeader(SegmentMode.ACTIVE, 255, WasmType.FUNCREF_TYPE, 0, null, -1), new byte[]{0x40, 0x10, (byte) 0xFF}); + test(b -> b.addElemHeader(SegmentMode.ACTIVE, 255, WasmType.FUNCREF_TYPE, 0, null, -1), new byte[]{0x40, 0x10, WasmType.FUNCREF_TYPE, (byte) 0xFF}); } @Test public void testElemHeaderMinU16Count() { - test(b -> b.addElemHeader(SegmentMode.ACTIVE, 256, WasmType.FUNCREF_TYPE, 0, null, -1), new byte[]{(byte) 0x80, 0x10, 0x00, 0x01}); + test(b -> b.addElemHeader(SegmentMode.ACTIVE, 256, WasmType.FUNCREF_TYPE, 0, null, -1), new byte[]{(byte) 0x80, 0x10, WasmType.FUNCREF_TYPE, 0x00, 0x01}); } @Test public void testElemHeaderMaxU16Count() { - test(b -> b.addElemHeader(SegmentMode.ACTIVE, 65535, WasmType.FUNCREF_TYPE, 0, null, -1), new byte[]{(byte) 0x80, 0x10, (byte) 0xFF, (byte) 0xFF}); + test(b -> b.addElemHeader(SegmentMode.ACTIVE, 65535, WasmType.FUNCREF_TYPE, 0, null, -1), new byte[]{(byte) 0x80, 0x10, WasmType.FUNCREF_TYPE, (byte) 0xFF, (byte) 0xFF}); } @Test public void testElemHeaderMinI32Count() { - test(b -> b.addElemHeader(SegmentMode.ACTIVE, 65536, WasmType.FUNCREF_TYPE, 0, null, -1), new byte[]{(byte) 0xC0, 0x10, 0x00, 0x00, 0x01, 0x00}); + test(b -> b.addElemHeader(SegmentMode.ACTIVE, 65536, WasmType.FUNCREF_TYPE, 0, null, -1), new byte[]{(byte) 0xC0, 0x10, WasmType.FUNCREF_TYPE, 0x00, 0x00, 0x01, 0x00}); } @Test public void testElemHeaderPassive() { - test(b -> b.addElemHeader(SegmentMode.PASSIVE, 8, WasmType.FUNCREF_TYPE, 0, null, -1), new byte[]{0x40, 0x11, 0x08}); + test(b -> b.addElemHeader(SegmentMode.PASSIVE, 8, WasmType.FUNCREF_TYPE, 0, null, -1), new byte[]{0x40, 0x11, WasmType.FUNCREF_TYPE, 0x08}); } @Test public void testElemHeaderDeclarative() { - test(b -> b.addElemHeader(SegmentMode.DECLARATIVE, 8, WasmType.FUNCREF_TYPE, 0, null, -1), new byte[]{0x40, 0x12, 0x08}); + test(b -> b.addElemHeader(SegmentMode.DECLARATIVE, 8, WasmType.FUNCREF_TYPE, 0, null, -1), new byte[]{0x40, 0x12, WasmType.FUNCREF_TYPE, 0x08}); } @Test public void testElemHeaderExternref() { - test(b -> b.addElemHeader(SegmentMode.ACTIVE, 8, WasmType.EXTERNREF_TYPE, 0, null, -1), new byte[]{0x40, 0x20, 0x08}); + test(b -> b.addElemHeader(SegmentMode.ACTIVE, 8, WasmType.EXTERNREF_TYPE, 0, null, -1), new byte[]{0x40, 0x10, WasmType.EXTERNREF_TYPE, 0x08}); } @Test @@ -694,87 +694,87 @@ public void testElemHeaderExnref() { @Test public void testElemHeaderMinU8TableIndex() { - test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 1, null, -1), new byte[]{0x50, 0x10, 0x00, 0x01}); + test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 1, null, -1), new byte[]{0x50, 0x10, WasmType.FUNCREF_TYPE, 0x00, 0x01}); } @Test public void testElemHeaderMaxU8TableIndex() { - test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 255, null, -1), new byte[]{0x50, 0x10, 0x00, (byte) 0xFF}); + test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 255, null, -1), new byte[]{0x50, 0x10, WasmType.FUNCREF_TYPE, 0x00, (byte) 0xFF}); } @Test public void testElemHeaderMinU16TableIndex() { - test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 256, null, -1), new byte[]{0x60, 0x10, 0x00, 0x00, 0x01}); + test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 256, null, -1), new byte[]{0x60, 0x10, WasmType.FUNCREF_TYPE, 0x00, 0x00, 0x01}); } @Test public void testElemHeaderMaxU16TableIndex() { - test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 65535, null, -1), new byte[]{0x60, 0x10, 0x00, (byte) 0xFF, (byte) 0xFF}); + test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 65535, null, -1), new byte[]{0x60, 0x10, WasmType.FUNCREF_TYPE, 0x00, (byte) 0xFF, (byte) 0xFF}); } @Test public void testElemHeaderMinI32TableIndex() { - test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 65536, null, -1), new byte[]{0x70, 0x10, 0x00, 0x00, 0x00, 0x01, 0x00}); + test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 65536, null, -1), new byte[]{0x70, 0x10, WasmType.FUNCREF_TYPE, 0x00, 0x00, 0x00, 0x01, 0x00}); } @Test public void testElemHeaderMinU8OffsetBytecodeLength() { byte[] offsetBytecode = new byte[0]; test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 0, offsetBytecode, -1), - byteArrayConcat(new byte[]{0x44, 0x10, 0x00, 0x00}, offsetBytecode)); + byteArrayConcat(new byte[]{0x44, 0x10, WasmType.FUNCREF_TYPE, 0x00, 0x00}, offsetBytecode)); } @Test public void testElemHeaderMaxU8OffsetBytecodeLength() { byte[] offsetBytecode = new byte[255]; test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 0, offsetBytecode, -1), - byteArrayConcat(new byte[]{0x44, 0x10, 0x00, (byte) 0xFF}, offsetBytecode)); + byteArrayConcat(new byte[]{0x44, 0x10, WasmType.FUNCREF_TYPE, 0x00, (byte) 0xFF}, offsetBytecode)); } @Test public void testElemHeaderMinU16OffsetBytecodeLength() { byte[] offsetBytecode = new byte[256]; test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 0, offsetBytecode, -1), - byteArrayConcat(new byte[]{0x48, 0x10, 0x00, 0x00, 0x01}, offsetBytecode)); + byteArrayConcat(new byte[]{0x48, 0x10, WasmType.FUNCREF_TYPE, 0x00, 0x00, 0x01}, offsetBytecode)); } @Test public void testElemHeaderMaxU16OffsetBytecodeLength() { byte[] offsetBytecode = new byte[65535]; test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 0, offsetBytecode, -1), - byteArrayConcat(new byte[]{0x48, 0x10, 0x00, (byte) 0xFF, (byte) 0xFF}, offsetBytecode)); + byteArrayConcat(new byte[]{0x48, 0x10, WasmType.FUNCREF_TYPE, 0x00, (byte) 0xFF, (byte) 0xFF}, offsetBytecode)); } @Test public void testElemHeaderMinI32OffsetBytecodeLength() { byte[] offsetBytecode = new byte[65536]; test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 0, offsetBytecode, -1), - byteArrayConcat(new byte[]{0x4C, 0x10, 0x00, 0x00, 0x00, 0x01, 0x00}, offsetBytecode)); + byteArrayConcat(new byte[]{0x4C, 0x10, WasmType.FUNCREF_TYPE, 0x00, 0x00, 0x00, 0x01, 0x00}, offsetBytecode)); } @Test public void testElemHeaderMinU8OffsetAddress() { - test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 0, null, 0), new byte[]{0x41, 0x10, 0x00, 0x00}); + test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 0, null, 0), new byte[]{0x41, 0x10, WasmType.FUNCREF_TYPE, 0x00, 0x00}); } @Test public void testElemHeaderMaxU8OffsetAddress() { - test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 0, null, 255), new byte[]{0x41, 0x10, 0x00, (byte) 0xFF}); + test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 0, null, 255), new byte[]{0x41, 0x10, WasmType.FUNCREF_TYPE, 0x00, (byte) 0xFF}); } @Test public void testElemHeaderMinU16OffsetAddress() { - test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 0, null, 256), new byte[]{0x42, 0x10, 0x00, 0x00, 0x01}); + test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 0, null, 256), new byte[]{0x42, 0x10, WasmType.FUNCREF_TYPE, 0x00, 0x00, 0x01}); } @Test public void testElemHeaderMaxU16OffsetAddress() { - test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 0, null, 65535), new byte[]{0x42, 0x10, 0x00, (byte) 0xFF, (byte) 0xFF}); + test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 0, null, 65535), new byte[]{0x42, 0x10, WasmType.FUNCREF_TYPE, 0x00, (byte) 0xFF, (byte) 0xFF}); } @Test public void testElemHeaderMinI32OffsetAddress() { - test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 0, null, 65536), new byte[]{0x43, 0x10, 0x00, 0x00, 0x00, 0x01, 0x00}); + test(b -> b.addElemHeader(SegmentMode.ACTIVE, 0, WasmType.FUNCREF_TYPE, 0, null, 65536), new byte[]{0x43, 0x10, WasmType.FUNCREF_TYPE, 0x00, 0x00, 0x00, 0x01, 0x00}); } @Test diff --git a/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/debugging/DebugValidationSuite.java b/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/debugging/DebugValidationSuite.java index 3063bf52f758..d4ec02972d01 100644 --- a/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/debugging/DebugValidationSuite.java +++ b/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/debugging/DebugValidationSuite.java @@ -59,7 +59,7 @@ public class DebugValidationSuite extends AbstractBinarySuite { private static AbstractBinarySuite.BinaryBuilder getDefaultDebugBuilder() { - return newBuilder().addType(EMPTY_BYTES, EMPTY_BYTES).addFunction((byte) 0, EMPTY_BYTES, "0B").addFunctionExport((byte) 0, "_main"); + return newBuilder().addType(EMPTY_INTS, EMPTY_INTS).addFunction(0, EMPTY_INTS, "0B").addFunctionExport(0, "_main"); } @Test diff --git a/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/validation/ReferenceTypesValidationSuite.java b/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/validation/ReferenceTypesValidationSuite.java index e5b22eb4fd84..8488708c5abb 100644 --- a/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/validation/ReferenceTypesValidationSuite.java +++ b/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/validation/ReferenceTypesValidationSuite.java @@ -75,9 +75,9 @@ private static AbstractBinarySuite.BinaryBuilder getDefaultTableInitBuilder(Stri // ;; main hex code // ) // ) - return newBuilder().addType(EMPTY_BYTES, new byte[]{WasmType.I32_TYPE}).addTable((byte) 3, (byte) 3, WasmType.FUNCREF_TYPE).addFunction((byte) 0, EMPTY_BYTES, "41 00 0B").addFunction((byte) 0, - EMPTY_BYTES, "41 01 0B").addFunction((byte) 0, EMPTY_BYTES, "41 02 0B").addFunction((byte) 0, EMPTY_BYTES, mainHexCode).addFunctionExport( - (byte) 3, "main"); + return newBuilder().addType(EMPTY_INTS, new int[]{WasmType.I32_TYPE}).addTable(3, 3, WasmType.FUNCREF_TYPE).addFunction(0, EMPTY_INTS, "41 00 0B").addFunction(0, + EMPTY_INTS, "41 01 0B").addFunction(0, EMPTY_INTS, "41 02 0B").addFunction(0, EMPTY_INTS, mainHexCode).addFunctionExport( + 3, "main"); } @Test @@ -222,7 +222,7 @@ public void testTableInitDeclarativeType3() throws IOException { // i32.const 0 // // (elem declare (ref.func 3)) - final byte[] binary = getDefaultTableInitBuilder("D2 03 1A 41 00 0B").addFunction((byte) 0, EMPTY_BYTES, "41 03 0B").addElements("03 00 01 03").build(); + final byte[] binary = getDefaultTableInitBuilder("D2 03 1A 41 00 0B").addFunction(0, EMPTY_INTS, "41 03 0B").addElements("03 00 01 03").build(); runRuntimeTest(binary, instance -> { Value main = instance.getMember("main"); Value result = main.execute(); @@ -238,7 +238,7 @@ public void testTableInitDeclarativeType7() throws IOException { // i32.const 0 // // (elem declare (ref.func 3)) - final byte[] binary = getDefaultTableInitBuilder("D2 03 1A 41 00 0B").addFunction((byte) 0, EMPTY_BYTES, "41 03 0B").addElements("07 70 01 D2 03 0B").build(); + final byte[] binary = getDefaultTableInitBuilder("D2 03 1A 41 00 0B").addFunction(0, EMPTY_INTS, "41 03 0B").addElements("07 70 01 D2 03 0B").build(); runRuntimeTest(binary, instance -> { Value main = instance.getMember("main"); Value result = main.execute(); @@ -535,15 +535,16 @@ public void testMultipleTables() throws IOException { // (table 1 1 funcref) // (table 1 1 externref) // (table 1 1 exnref) - final byte[] binary = newBuilder().addTable((byte) 1, (byte) 1, WasmType.FUNCREF_TYPE).addTable((byte) 1, (byte) 1, WasmType.EXTERNREF_TYPE).addTable((byte) 1, (byte) 1, - WasmType.EXNREF_TYPE).build(); + final byte[] binary = newBuilder().addTable(1, 1, WasmType.FUNCREF_TYPE).addTable(1, 1, WasmType.EXTERNREF_TYPE) + .addTable(1, 1, WasmType.EXNREF_TYPE).build(); runParserTest(binary, options -> options.option("wasm.Exceptions", "true"), Context::eval); + runParserTest(binary, Context::eval); } @Test public void testTableInvalidElemType() throws IOException { // (table 1 1 i32) - final byte[] binary = newBuilder().addTable((byte) 1, (byte) 1, WasmType.I32_TYPE).build(); + final byte[] binary = newBuilder().addTable(1, 1, WasmType.I32_TYPE).build(); runParserTest(binary, (context, source) -> { try { context.eval(source); @@ -564,7 +565,7 @@ private static AbstractBinarySuite.BinaryBuilder getDefaultMemoryInitBuilder(Str // ;; main hex code // ) // ) - return newBuilder().addType(EMPTY_BYTES, new byte[]{WasmType.I32_TYPE}).addMemory((byte) 1, (byte) 1).addFunction((byte) 0, EMPTY_BYTES, mainHexCode).addFunctionExport((byte) 0, "main"); + return newBuilder().addType(EMPTY_INTS, new int[]{WasmType.I32_TYPE}).addMemory(1, 1).addFunction(0, EMPTY_INTS, mainHexCode).addFunctionExport(0, "main"); } @Test @@ -875,8 +876,8 @@ public void testGlobalWithNull() throws IOException { // (func (export "main") (type 0) // global.get 0 // ) - final byte[] binary = newBuilder().addGlobal(GlobalModifier.CONSTANT, WasmType.EXTERNREF_TYPE, "D0 6F 0B").addType(EMPTY_BYTES, new byte[]{WasmType.EXTERNREF_TYPE}).addFunction((byte) 0, - EMPTY_BYTES, "23 00 0B").addFunctionExport((byte) 0, "main").build(); + final byte[] binary = newBuilder().addGlobal(GlobalModifier.CONSTANT, WasmType.EXTERNREF_TYPE, "D0 6F 0B").addType(EMPTY_INTS, new int[]{WasmType.EXTERNREF_TYPE}).addFunction(0, + EMPTY_INTS, "23 00 0B").addFunctionExport(0, "main").build(); runRuntimeTest(binary, instance -> { Value main = instance.getMember("main"); Value result = main.execute(); @@ -891,8 +892,8 @@ public void testGlobalWithNullException() throws IOException { // (func (export "main") (type 0) // global.get 0 // ) - final byte[] binary = newBuilder().addGlobal(GlobalModifier.CONSTANT, WasmType.EXNREF_TYPE, "D0 69 0B").addType(EMPTY_BYTES, new byte[]{WasmType.EXNREF_TYPE}).addFunction((byte) 0, - EMPTY_BYTES, "23 00 0B").addFunctionExport((byte) 0, "main").build(); + final byte[] binary = newBuilder().addGlobal(GlobalModifier.CONSTANT, WasmType.EXNREF_TYPE, "D0 69 0B").addType(EMPTY_INTS, new int[]{WasmType.EXNREF_TYPE}).addFunction(0, + EMPTY_INTS, "23 00 0B").addFunctionExport(0, "main").build(); runRuntimeTest(binary, options -> options.option("wasm.Exceptions", "true"), instance -> { Value main = instance.getMember("main"); Value result = main.execute(); @@ -915,9 +916,9 @@ public void testGlobalWithFunction() throws IOException { // i32.const 0 // call_indirect 0 (type 0) // ) - final byte[] binary = newBuilder().addGlobal(GlobalModifier.CONSTANT, WasmType.FUNCREF_TYPE, "D2 00 0B").addTable((byte) 1, (byte) 1, WasmType.FUNCREF_TYPE).addType(EMPTY_BYTES, - new byte[]{WasmType.I32_TYPE}).addFunction((byte) 0, EMPTY_BYTES, "41 01 0B").addFunction((byte) 0, EMPTY_BYTES, "41 00 23 00 26 00 41 00 11 00 00 0B").addFunctionExport( - (byte) 1, "main").build(); + final byte[] binary = newBuilder().addGlobal(GlobalModifier.CONSTANT, WasmType.FUNCREF_TYPE, "D2 00 0B").addTable(1, 1, WasmType.FUNCREF_TYPE).addType(EMPTY_INTS, + new int[]{WasmType.I32_TYPE}).addFunction(0, EMPTY_INTS, "41 01 0B").addFunction(0, EMPTY_INTS, "41 00 23 00 26 00 41 00 11 00 00 0B").addFunctionExport( + 1, "main").build(); runRuntimeTest(binary, instance -> { Value main = instance.getMember("main"); Value result = main.execute(); @@ -956,13 +957,13 @@ public void testMultiValueWithReferenceTypes() throws IOException { // table.set 1 // ) // (elem (table 0) (ref.func 0) (ref.func 1)) - final byte[] binary = newBuilder().addTable((byte) 2, (byte) 2, WasmType.FUNCREF_TYPE).addTable((byte) 2, (byte) 2, WasmType.EXTERNREF_TYPE).addType(EMPTY_BYTES, - new byte[]{WasmType.I32_TYPE}).addType(new byte[]{WasmType.I32_TYPE, WasmType.FUNCREF_TYPE}, new byte[]{WasmType.EXTERNREF_TYPE, WasmType.FUNCREF_TYPE}).addType(EMPTY_BYTES, - new byte[]{WasmType.EXTERNREF_TYPE, WasmType.FUNCREF_TYPE}).addType(new byte[]{WasmType.I32_TYPE, WasmType.EXTERNREF_TYPE}, EMPTY_BYTES).addFunction((byte) 0, - EMPTY_BYTES, "41 01 0B").addFunction((byte) 0, EMPTY_BYTES, "41 02 0B").addFunction((byte) 1, EMPTY_BYTES, "20 00 25 01 20 01 0B").addFunction( - (byte) 2, EMPTY_BYTES, "41 01 41 00 25 00 10 02 0B").addFunction((byte) 3, EMPTY_BYTES, - "20 00 20 01 26 01 0B").addFunctionExport((byte) 3, - "main").addFunctionExport((byte) 4, "setRef").addElements("00 41 00 0B 02 00 01").build(); + final byte[] binary = newBuilder().addTable(2, 2, WasmType.FUNCREF_TYPE).addTable(2, 2, WasmType.EXTERNREF_TYPE).addType(EMPTY_INTS, + new int[]{WasmType.I32_TYPE}).addType(new int[]{WasmType.I32_TYPE, WasmType.FUNCREF_TYPE}, new int[]{WasmType.EXTERNREF_TYPE, WasmType.FUNCREF_TYPE}).addType(EMPTY_INTS, + new int[]{WasmType.EXTERNREF_TYPE, WasmType.FUNCREF_TYPE}).addType(new int[]{WasmType.I32_TYPE, WasmType.EXTERNREF_TYPE}, EMPTY_INTS).addFunction(0, + EMPTY_INTS, "41 01 0B").addFunction(0, EMPTY_INTS, "41 02 0B").addFunction(1, EMPTY_INTS, "20 00 25 01 20 01 0B").addFunction( + 2, EMPTY_INTS, "41 01 41 00 25 00 10 02 0B").addFunction(3, EMPTY_INTS, + "20 00 20 01 26 01 0B").addFunctionExport(3, + "main").addFunctionExport(4, "setRef").addElements("00 41 00 0B 02 00 01").build(); runRuntimeTest(binary, instance -> { Value setRef = instance.getMember("setRef"); setRef.execute(0, "foo"); diff --git a/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/validation/ValidationSuite.java b/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/validation/ValidationSuite.java index 1e4c35c44603..03c6ba46e99d 100644 --- a/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/validation/ValidationSuite.java +++ b/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/validation/ValidationSuite.java @@ -167,7 +167,7 @@ public static Collection data() { Failure.Type.MALFORMED), binaryCase( "Table - import with invalid elemtype", - "Invalid element type for table import: 0x6F should = 0x70", + "Invalid element type for table import: -17 should = -16", // (import "a" "b" (table 0 1 externref)) "00 61 73 6D 01 00 00 00 02 09 01 01 61 01 62 01 6F 00 01", Failure.Type.MALFORMED), diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java index c2609ef80789..71664b4c55cf 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java @@ -49,15 +49,20 @@ import static org.graalvm.wasm.Assert.assertUnsignedLongLessOrEqual; import static org.graalvm.wasm.Assert.fail; import static org.graalvm.wasm.WasmType.EXNREF_TYPE; +import static org.graalvm.wasm.WasmType.EXN_HEAPTYPE; import static org.graalvm.wasm.WasmType.EXTERNREF_TYPE; +import static org.graalvm.wasm.WasmType.EXTERN_HEAPTYPE; import static org.graalvm.wasm.WasmType.F32_TYPE; import static org.graalvm.wasm.WasmType.F64_TYPE; import static org.graalvm.wasm.WasmType.FUNCREF_TYPE; +import static org.graalvm.wasm.WasmType.FUNC_HEAPTYPE; import static org.graalvm.wasm.WasmType.I32_TYPE; import static org.graalvm.wasm.WasmType.I64_TYPE; import static org.graalvm.wasm.WasmType.NULL_TYPE; +import static org.graalvm.wasm.WasmType.REF_NULL_TYPE_HEADER; +import static org.graalvm.wasm.WasmType.REF_TYPE_HEADER; import static org.graalvm.wasm.WasmType.V128_TYPE; -import static org.graalvm.wasm.WasmType.VOID_TYPE; +import static org.graalvm.wasm.WasmType.VOID_BLOCK_TYPE; import static org.graalvm.wasm.constants.Bytecode.vectorOpcodeToBytecode; import static org.graalvm.wasm.constants.Sizes.MAX_MEMORY_64_DECLARATION_SIZE; import static org.graalvm.wasm.constants.Sizes.MAX_MEMORY_DECLARATION_SIZE; @@ -76,7 +81,7 @@ import org.graalvm.collections.Pair; import org.graalvm.wasm.api.Vector128; import org.graalvm.wasm.api.Vector128Shape; -import org.graalvm.wasm.collection.ByteArrayList; +import org.graalvm.wasm.collection.IntArrayList; import org.graalvm.wasm.constants.Bytecode; import org.graalvm.wasm.constants.BytecodeBitEncoding; import org.graalvm.wasm.constants.ExceptionHandlerType; @@ -426,9 +431,9 @@ private void readImportSection() { break; } case ImportIdentifier.TABLE: { - final byte elemType = readRefType(exceptions); + final int elemType = readRefType(exceptions); if (!bulkMemoryAndRefTypes) { - assertByteEqual(elemType, FUNCREF_TYPE, "Invalid element type for table import", Failure.UNSPECIFIED_MALFORMED); + assertIntEqual(elemType, FUNCREF_TYPE, Failure.UNSPECIFIED_MALFORMED, "Invalid element type for table import"); } readTableLimits(multiResult); final int tableIndex = module.tableCount(); @@ -444,7 +449,7 @@ private void readImportSection() { break; } case ImportIdentifier.GLOBAL: { - byte type = readValueType(bulkMemoryAndRefTypes, simd, exceptions); + int type = readValueType(bulkMemoryAndRefTypes, simd, exceptions); byte mutability = readMutability(); int globalIndex = module.symbolTable().numGlobals(); module.symbolTable().importGlobal(moduleName, memberName, globalIndex, type, mutability); @@ -483,7 +488,7 @@ private void readTableSection() { module.limits().checkTableCount(startingTableIndex + tableCount); for (int tableIndex = startingTableIndex; tableIndex != startingTableIndex + tableCount; tableIndex++) { assertTrue(!isEOF(), Failure.LENGTH_OUT_OF_BOUNDS); - final byte elemType = readRefType(exceptions); + final int elemType = readRefType(exceptions); readTableLimits(multiResult); module.symbolTable().allocateTable(tableIndex, multiResult[0], multiResult[1], elemType, bulkMemoryAndRefTypes); } @@ -516,7 +521,7 @@ private void readCodeSection(RuntimeBytecodeGen bytecode, BytecodeGen functionDe final int codeEntrySize = readUnsignedInt32(); final int startOffset = offset; module.limits().checkFunctionSize(codeEntrySize); - final ByteArrayList locals = readCodeEntryLocals(); + final IntArrayList locals = readCodeEntryLocals(); final int localCount = locals.size() + module.function(importedFunctionCount + entryIndex).paramCount(); module.limits().checkLocalCount(localCount); // Store the function start offset, instruction start offset, and function end offset. @@ -529,33 +534,33 @@ private void readCodeSection(RuntimeBytecodeGen bytecode, BytecodeGen functionDe module.setCodeEntries(codeEntries); } - private CodeEntry readCodeEntry(int functionIndex, ByteArrayList locals, int endOffset, boolean hasNextFunction, RuntimeBytecodeGen bytecode, int codeEntryIndex) { + private CodeEntry readCodeEntry(int functionIndex, IntArrayList locals, int endOffset, boolean hasNextFunction, RuntimeBytecodeGen bytecode, int codeEntryIndex) { final WasmFunction function = module.symbolTable().function(functionIndex); int paramCount = function.paramCount(); - byte[] localTypes = new byte[function.paramCount() + locals.size()]; + int[] localTypes = new int[function.paramCount() + locals.size()]; for (int index = 0; index != paramCount; index++) { localTypes[index] = function.paramTypeAt(index); } for (int index = 0; index != locals.size(); index++) { localTypes[index + paramCount] = locals.get(index); } - byte[] resultTypes = new byte[function.resultCount()]; + int[] resultTypes = new int[function.resultCount()]; for (int index = 0; index != resultTypes.length; index++) { resultTypes[index] = function.resultTypeAt(index); } return readFunction(functionIndex, localTypes, resultTypes, endOffset, hasNextFunction, bytecode, codeEntryIndex, null); } - private ByteArrayList readCodeEntryLocals() { + private IntArrayList readCodeEntryLocals() { final int localsGroupCount = readLength(); - final ByteArrayList localTypes = new ByteArrayList(); + final IntArrayList localTypes = new IntArrayList(); int localsLength = 0; for (int localGroup = 0; localGroup != localsGroupCount; localGroup++) { assertTrue(!isEOF(), Failure.LENGTH_OUT_OF_BOUNDS); final int groupLength = readUnsignedInt32(); localsLength += groupLength; module.limits().checkLocalCount(localsLength); - final byte t = readValueType(bulkMemoryAndRefTypes, simd, exceptions); + final int t = readValueType(bulkMemoryAndRefTypes, simd, exceptions); for (int i = 0; i != groupLength; ++i) { localTypes.add(t); } @@ -563,27 +568,26 @@ private ByteArrayList readCodeEntryLocals() { return localTypes; } - private byte[] extractBlockParamTypes(int typeIndex) { + private int[] extractBlockParamTypes(int typeIndex) { int paramCount = module.functionTypeParamCount(typeIndex); - byte[] params = new byte[paramCount]; + int[] params = new int[paramCount]; for (int i = 0; i < paramCount; i++) { params[i] = module.functionTypeParamTypeAt(typeIndex, i); } return params; } - private byte[] extractBlockResultTypes(int typeIndex) { + private int[] extractBlockResultTypes(int typeIndex) { int resultCount = module.functionTypeResultCount(typeIndex); - byte[] results = new byte[resultCount]; + int[] results = new int[resultCount]; for (int i = 0; i < resultCount; i++) { results[i] = module.functionTypeResultTypeAt(typeIndex, i); } return results; } - private static byte[] encapsulateResultType(int type) { + private static int[] encapsulateResultType(int type) { return switch (type) { - case VOID_TYPE -> WasmType.VOID_TYPE_ARRAY; case I32_TYPE -> WasmType.I32_TYPE_ARRAY; case I64_TYPE -> WasmType.I64_TYPE_ARRAY; case F32_TYPE -> WasmType.F32_TYPE_ARRAY; @@ -592,11 +596,11 @@ private static byte[] encapsulateResultType(int type) { case FUNCREF_TYPE -> WasmType.FUNCREF_TYPE_ARRAY; case EXTERNREF_TYPE -> WasmType.EXTERNREF_TYPE_ARRAY; case EXNREF_TYPE -> WasmType.EXNREF_TYPE_ARRAY; - default -> throw WasmException.create(Failure.UNSPECIFIED_INTERNAL); + default -> new int[]{type}; }; } - private CodeEntry readFunction(int functionIndex, byte[] locals, byte[] resultTypes, int sourceCodeEndOffset, boolean hasNextFunction, RuntimeBytecodeGen bytecode, + private CodeEntry readFunction(int functionIndex, int[] locals, int[] resultTypes, int sourceCodeEndOffset, boolean hasNextFunction, RuntimeBytecodeGen bytecode, int codeEntryIndex, EconomicMap offsetToLineIndexMap) { final ParserState state = new ParserState(bytecode); final ArrayList callNodes = new ArrayList<>(); @@ -623,20 +627,26 @@ private CodeEntry readFunction(int functionIndex, byte[] locals, byte[] resultTy state.addInstruction(Bytecode.NOP); break; case Instructions.BLOCK: { - final byte[] blockParamTypes; - final byte[] blockResultTypes; + final int[] blockParamTypes; + final int[] blockResultTypes; readBlockType(multiResult, bulkMemoryAndRefTypes, simd, exceptions); // Extract value based on result arity. - if (multiResult[1] == SINGLE_RESULT_VALUE) { - blockParamTypes = WasmType.VOID_TYPE_ARRAY; - blockResultTypes = encapsulateResultType(multiResult[0]); - } else if (multiValue) { - int typeIndex = multiResult[0]; - state.checkFunctionTypeExists(typeIndex, module.typeCount()); - blockParamTypes = extractBlockParamTypes(typeIndex); - blockResultTypes = extractBlockResultTypes(typeIndex); - } else { - throw WasmException.create(Failure.DISABLED_MULTI_VALUE); + switch (multiResult[1]) { + case BLOCK_TYPE_VOID -> { + blockParamTypes = WasmType.VOID_TYPE_ARRAY; + blockResultTypes = WasmType.VOID_TYPE_ARRAY; + } + case BLOCK_TYPE_VALTYPE -> { + blockParamTypes = WasmType.VOID_TYPE_ARRAY; + blockResultTypes = encapsulateResultType(multiResult[0]); + } + case BLOCK_TYPE_TYPE_INDEX -> { + int typeIndex = multiResult[0]; + state.checkFunctionTypeExists(typeIndex, module.typeCount()); + blockParamTypes = extractBlockParamTypes(typeIndex); + blockResultTypes = extractBlockResultTypes(typeIndex); + } + default -> throw WasmException.create(Failure.DISABLED_MULTI_VALUE); } state.popAll(blockParamTypes); state.enterBlock(blockParamTypes, blockResultTypes); @@ -644,20 +654,26 @@ private CodeEntry readFunction(int functionIndex, byte[] locals, byte[] resultTy } case Instructions.LOOP: { // Jumps are targeting the loop instruction for OSR. - final byte[] loopParamTypes; - final byte[] loopResultTypes; + final int[] loopParamTypes; + final int[] loopResultTypes; readBlockType(multiResult, bulkMemoryAndRefTypes, simd, exceptions); // Extract value based on result arity. - if (multiResult[1] == SINGLE_RESULT_VALUE) { - loopParamTypes = WasmType.VOID_TYPE_ARRAY; - loopResultTypes = encapsulateResultType(multiResult[0]); - } else if (multiValue) { - int typeIndex = multiResult[0]; - state.checkFunctionTypeExists(typeIndex, module.typeCount()); - loopParamTypes = extractBlockParamTypes(typeIndex); - loopResultTypes = extractBlockResultTypes(typeIndex); - } else { - throw WasmException.create(Failure.DISABLED_MULTI_VALUE); + switch (multiResult[1]) { + case BLOCK_TYPE_VOID -> { + loopParamTypes = WasmType.VOID_TYPE_ARRAY; + loopResultTypes = WasmType.VOID_TYPE_ARRAY; + } + case BLOCK_TYPE_VALTYPE -> { + loopParamTypes = WasmType.VOID_TYPE_ARRAY; + loopResultTypes = encapsulateResultType(multiResult[0]); + } + case BLOCK_TYPE_TYPE_INDEX -> { + int typeIndex = multiResult[0]; + state.checkFunctionTypeExists(typeIndex, module.typeCount()); + loopParamTypes = extractBlockParamTypes(typeIndex); + loopResultTypes = extractBlockResultTypes(typeIndex); + } + default -> throw WasmException.create(Failure.DISABLED_MULTI_VALUE); } state.popAll(loopParamTypes); state.enterLoop(loopParamTypes, loopResultTypes); @@ -665,27 +681,33 @@ private CodeEntry readFunction(int functionIndex, byte[] locals, byte[] resultTy } case Instructions.IF: { state.popChecked(I32_TYPE); // condition - final byte[] ifParamTypes; - final byte[] ifResultTypes; + final int[] ifParamTypes; + final int[] ifResultTypes; readBlockType(multiResult, bulkMemoryAndRefTypes, simd, exceptions); // Extract value based on result arity. - if (multiResult[1] == SINGLE_RESULT_VALUE) { - ifParamTypes = WasmType.VOID_TYPE_ARRAY; - ifResultTypes = encapsulateResultType(multiResult[0]); - } else if (multiValue) { - int typeIndex = multiResult[0]; - state.checkFunctionTypeExists(typeIndex, module.typeCount()); - ifParamTypes = extractBlockParamTypes(typeIndex); - ifResultTypes = extractBlockResultTypes(typeIndex); - } else { - throw WasmException.create(Failure.DISABLED_MULTI_VALUE); + switch (multiResult[1]) { + case BLOCK_TYPE_VOID -> { + ifParamTypes = WasmType.VOID_TYPE_ARRAY; + ifResultTypes = WasmType.VOID_TYPE_ARRAY; + } + case BLOCK_TYPE_VALTYPE -> { + ifParamTypes = WasmType.VOID_TYPE_ARRAY; + ifResultTypes = encapsulateResultType(multiResult[0]); + } + case BLOCK_TYPE_TYPE_INDEX -> { + int typeIndex = multiResult[0]; + state.checkFunctionTypeExists(typeIndex, module.typeCount()); + ifParamTypes = extractBlockParamTypes(typeIndex); + ifResultTypes = extractBlockResultTypes(typeIndex); + } + default -> throw WasmException.create(Failure.DISABLED_MULTI_VALUE); } state.popAll(ifParamTypes); state.enterIf(ifParamTypes, ifResultTypes); break; } case Instructions.END: { - final byte[] endResultTypes = state.exit(multiValue); + final int[] endResultTypes = state.exit(multiValue); state.pushAll(endResultTypes); if (state.controlStackSize() == 0) { /* @@ -751,7 +773,7 @@ private CodeEntry readFunction(int functionIndex, byte[] locals, byte[] resultTy // Pop parameters final WasmFunction function = module.function(callFunctionIndex); - byte[] params = new byte[function.paramCount()]; + int[] params = new int[function.paramCount()]; for (int i = function.paramCount() - 1; i >= 0; --i) { params[i] = function.paramTypeAt(i); } @@ -772,7 +794,7 @@ private CodeEntry readFunction(int functionIndex, byte[] locals, byte[] resultTy // Pop the function index to call state.popChecked(I32_TYPE); state.checkFunctionTypeExists(expectedFunctionTypeIndex, module.typeCount()); - assertByteEqual(FUNCREF_TYPE, module.tableElementType(tableIndex), Failure.TYPE_MISMATCH); + assertIntEqual(FUNCREF_TYPE, module.tableElementType(tableIndex), Failure.TYPE_MISMATCH); // Pop parameters for (int i = module.functionTypeParamCount(expectedFunctionTypeIndex) - 1; i >= 0; --i) { @@ -783,7 +805,7 @@ private CodeEntry readFunction(int functionIndex, byte[] locals, byte[] resultTy if (!multiValue) { assertIntLessOrEqual(resultCount, 1, Failure.INVALID_RESULT_ARITY); } - byte[] callResultTypes = new byte[resultCount]; + int[] callResultTypes = new int[resultCount]; for (int i = 0; i < resultCount; i++) { callResultTypes[i] = module.functionTypeResultTypeAt(expectedFunctionTypeIndex, i); } @@ -793,7 +815,7 @@ private CodeEntry readFunction(int functionIndex, byte[] locals, byte[] resultTy break; } case Instructions.DROP: - final byte type = state.pop(); + final int type = state.pop(); if (WasmType.isNumberType(type)) { state.addInstruction(Bytecode.DROP); } else { @@ -802,11 +824,11 @@ private CodeEntry readFunction(int functionIndex, byte[] locals, byte[] resultTy break; case Instructions.SELECT: { state.popChecked(I32_TYPE); // condition - final byte t1 = state.pop(); // first operand - final byte t2 = state.pop(); // second operand + final int t1 = state.pop(); // first operand + final int t2 = state.pop(); // second operand assertTrue((WasmType.isNumberType(t1) || WasmType.isVectorType(t1)) && (WasmType.isNumberType(t2) || WasmType.isVectorType(t2)), Failure.TYPE_MISMATCH); assertTrue(t1 == t2 || t1 == WasmType.UNKNOWN_TYPE || t2 == WasmType.UNKNOWN_TYPE, Failure.TYPE_MISMATCH); - final byte t = t1 == WasmType.UNKNOWN_TYPE ? t2 : t1; + final int t = t1 == WasmType.UNKNOWN_TYPE ? t2 : t1; state.push(t); if (WasmType.isNumberType(t)) { state.addSelectInstruction(Bytecode.SELECT); @@ -819,7 +841,7 @@ private CodeEntry readFunction(int functionIndex, byte[] locals, byte[] resultTy checkBulkMemoryAndRefTypesSupport(opcode); final int length = readLength(); assertIntEqual(length, 1, Failure.INVALID_RESULT_ARITY); - final byte t = readValueType(bulkMemoryAndRefTypes, simd, exceptions); + final int t = readValueType(bulkMemoryAndRefTypes, simd, exceptions); state.popChecked(I32_TYPE); state.popChecked(t); state.popChecked(t); @@ -833,20 +855,26 @@ private CodeEntry readFunction(int functionIndex, byte[] locals, byte[] resultTy } case Instructions.TRY_TABLE: { checkExceptionHandlingSupport(opcode); - final byte[] tryTableParamTypes; - final byte[] tryTableResultTypes; + final int[] tryTableParamTypes; + final int[] tryTableResultTypes; readBlockType(multiResult, bulkMemoryAndRefTypes, simd, exceptions); - - if (multiResult[1] == SINGLE_RESULT_VALUE) { - tryTableParamTypes = WasmType.VOID_TYPE_ARRAY; - tryTableResultTypes = encapsulateResultType(multiResult[0]); - } else if (multiValue) { - final int typeIndex = multiResult[0]; - state.checkFunctionTypeExists(typeIndex, module.typeCount()); - tryTableParamTypes = extractBlockParamTypes(typeIndex); - tryTableResultTypes = extractBlockResultTypes(typeIndex); - } else { - throw WasmException.create(Failure.DISABLED_MULTI_VALUE); + // Extract value based on result arity. + switch (multiResult[1]) { + case BLOCK_TYPE_VOID -> { + tryTableParamTypes = WasmType.VOID_TYPE_ARRAY; + tryTableResultTypes = WasmType.VOID_TYPE_ARRAY; + } + case BLOCK_TYPE_VALTYPE -> { + tryTableParamTypes = WasmType.VOID_TYPE_ARRAY; + tryTableResultTypes = encapsulateResultType(multiResult[0]); + } + case BLOCK_TYPE_TYPE_INDEX -> { + int typeIndex = multiResult[0]; + state.checkFunctionTypeExists(typeIndex, module.typeCount()); + tryTableParamTypes = extractBlockParamTypes(typeIndex); + tryTableResultTypes = extractBlockResultTypes(typeIndex); + } + default -> throw WasmException.create(Failure.DISABLED_MULTI_VALUE); } state.popAll(tryTableParamTypes); final ExceptionHandler[] handlers = readExceptionHandlers(state); @@ -857,7 +885,7 @@ private CodeEntry readFunction(int functionIndex, byte[] locals, byte[] resultTy checkExceptionHandlingSupport(opcode); final int tagIndex = readTagIndex(); final int typeIndex = module.tagTypeIndex(tagIndex); - final byte[] paramTypes = module.typeAt(typeIndex).paramTypes(); + final int[] paramTypes = module.typeAt(typeIndex).paramTypes(); state.popAll(paramTypes); state.addMiscFlag(); state.addInstruction(Bytecode.THROW, tagIndex); @@ -877,7 +905,7 @@ private CodeEntry readFunction(int functionIndex, byte[] locals, byte[] resultTy case Instructions.LOCAL_GET: { final int localIndex = readLocalIndex(); assertUnsignedIntLess(localIndex, locals.length, Failure.UNKNOWN_LOCAL); - final byte localType = locals[localIndex]; + final int localType = locals[localIndex]; state.push(localType); if (WasmType.isNumberType(localType)) { state.addUnsignedInstruction(Bytecode.LOCAL_GET_U8, localIndex); @@ -889,7 +917,7 @@ private CodeEntry readFunction(int functionIndex, byte[] locals, byte[] resultTy case Instructions.LOCAL_SET: { final int localIndex = readLocalIndex(); assertUnsignedIntLess(localIndex, locals.length, Failure.UNKNOWN_LOCAL); - final byte localType = locals[localIndex]; + final int localType = locals[localIndex]; state.popChecked(localType); if (WasmType.isNumberType(localType)) { state.addUnsignedInstruction(Bytecode.LOCAL_SET_U8, localIndex); @@ -901,7 +929,7 @@ private CodeEntry readFunction(int functionIndex, byte[] locals, byte[] resultTy case Instructions.LOCAL_TEE: { final int localIndex = readLocalIndex(); assertUnsignedIntLess(localIndex, locals.length, Failure.UNKNOWN_LOCAL); - final byte localType = locals[localIndex]; + final int localType = locals[localIndex]; state.popChecked(localType); state.push(localType); if (WasmType.isNumberType(localType)) { @@ -929,7 +957,7 @@ private CodeEntry readFunction(int functionIndex, byte[] locals, byte[] resultTy case Instructions.TABLE_GET: { checkBulkMemoryAndRefTypesSupport(opcode); final int index = readTableIndex(); - final byte elementType = module.tableElementType(index); + final int elementType = module.tableElementType(index); state.popChecked(I32_TYPE); state.push(elementType); state.addInstruction(Bytecode.TABLE_GET, index); @@ -938,7 +966,7 @@ private CodeEntry readFunction(int functionIndex, byte[] locals, byte[] resultTy case Instructions.TABLE_SET: { checkBulkMemoryAndRefTypesSupport(opcode); final int index = readTableIndex(); - final byte elementType = module.tableElementType(index); + final int elementType = module.tableElementType(index); state.popChecked(elementType); state.popChecked(I32_TYPE); state.addInstruction(Bytecode.TABLE_SET, index); @@ -1116,14 +1144,14 @@ private CodeEntry readFunction(int functionIndex, byte[] locals, byte[] resultTy final int functionEndOffset = bytecode.location(); bytecode.addCodeEntry(functionIndex, state.maxStackSize(), bytecodeEndOffset - bytecodeStartOffset, locals.length, resultTypes.length); - for (byte local : locals) { - bytecode.addByte(local); + for (int local : locals) { + bytecode.addType(local); } if (locals.length != 0) { bytecode.addByte((byte) 0); } - for (byte result : resultTypes) { - bytecode.addByte(result); + for (int result : resultTypes) { + bytecode.addType(result); } if (resultTypes.length != 0) { bytecode.addByte((byte) 0); @@ -1530,7 +1558,7 @@ private void readNumericInstructions(ParserState state, int opcode) { final int elementIndex = readUnsignedInt32(); final int tableIndex = readTableIndex(); module.checkElemIndex(elementIndex); - final byte elementType = module.tableElementType(tableIndex); + final int elementType = module.tableElementType(tableIndex); module.checkElemType(elementIndex, elementType); state.popChecked(I32_TYPE); state.popChecked(I32_TYPE); @@ -1550,10 +1578,10 @@ private void readNumericInstructions(ParserState state, int opcode) { case Instructions.TABLE_COPY: checkBulkMemoryAndRefTypesSupport(miscOpcode); final int destinationTableIndex = readTableIndex(); - final byte destinationElementType = module.tableElementType(destinationTableIndex); + final int destinationElementType = module.tableElementType(destinationTableIndex); final int sourceTableIndex = readTableIndex(); - final byte sourceElementType = module.tableElementType(sourceTableIndex); - assertByteEqual(sourceElementType, destinationElementType, Failure.TYPE_MISMATCH); + final int sourceElementType = module.tableElementType(sourceTableIndex); + assertIntEqual(sourceElementType, destinationElementType, Failure.TYPE_MISMATCH); state.popChecked(I32_TYPE); state.popChecked(I32_TYPE); state.popChecked(I32_TYPE); @@ -1571,7 +1599,7 @@ private void readNumericInstructions(ParserState state, int opcode) { case Instructions.TABLE_GROW: { checkBulkMemoryAndRefTypesSupport(miscOpcode); final int tableIndex = readTableIndex(); - final byte elementType = module.tableElementType(tableIndex); + final int elementType = module.tableElementType(tableIndex); state.popChecked(I32_TYPE); state.popChecked(elementType); state.push(I32_TYPE); @@ -1582,7 +1610,7 @@ private void readNumericInstructions(ParserState state, int opcode) { case Instructions.TABLE_FILL: { checkBulkMemoryAndRefTypesSupport(miscOpcode); final int tableIndex = readTableIndex(); - final byte elementType = module.tableElementType(tableIndex); + final int elementType = module.tableElementType(tableIndex); state.popChecked(I32_TYPE); state.popChecked(elementType); state.popChecked(I32_TYPE); @@ -1611,7 +1639,7 @@ private void readNumericInstructions(ParserState state, int opcode) { break; case Instructions.REF_NULL: checkBulkMemoryAndRefTypesSupport(opcode); - final byte type = readRefType(exceptions); + final int type = readRefType(exceptions); state.push(type); state.addInstruction(Bytecode.REF_NULL); break; @@ -2384,7 +2412,7 @@ private void checkExceptionHandlingSupport(int opcode) { checkContextOption(wasmContext.getContextOptions().supportExceptions(), "Exception handling is not enabled (opcode: 0x%02x)", opcode); } - private void store(ParserState state, byte type, int n, long[] result) { + private void store(ParserState state, int type, int n, long[] result) { int alignHint = readAlignHint(n); final int memoryIndex = readMemoryIndexFromAlignHint(alignHint); final long memoryOffset = readBaseMemoryOffset(); @@ -2398,7 +2426,7 @@ private void store(ParserState state, byte type, int n, long[] result) { result[1] = memoryOffset; } - private void load(ParserState state, byte type, int n, long[] result) { + private void load(ParserState state, int type, int n, long[] result) { final int alignHint = readAlignHint(n); final int memoryIndex = readMemoryIndexFromAlignHint(alignHint); final long memoryOffset = readBaseMemoryOffset(); @@ -2412,7 +2440,7 @@ private void load(ParserState state, byte type, int n, long[] result) { result[1] = memoryOffset; } - private void atomicStore(ParserState state, byte type, int n, long[] result) { + private void atomicStore(ParserState state, int type, int n, long[] result) { int alignHint = readAtomicAlignHint(n); final int memoryIndex = readMemoryIndexFromAlignHint(alignHint); final long memoryOffset = readBaseMemoryOffset(); @@ -2426,7 +2454,7 @@ private void atomicStore(ParserState state, byte type, int n, long[] result) { result[1] = memoryOffset; } - private void atomicLoad(ParserState state, byte type, int n, long[] result) { + private void atomicLoad(ParserState state, int type, int n, long[] result) { final int alignHint = readAtomicAlignHint(n); final int memoryIndex = readMemoryIndexFromAlignHint(alignHint); final long memoryOffset = readBaseMemoryOffset(); @@ -2440,7 +2468,7 @@ private void atomicLoad(ParserState state, byte type, int n, long[] result) { result[1] = memoryOffset; } - private void atomicReadModifyWrite(ParserState state, byte type, int n, long[] result) { + private void atomicReadModifyWrite(ParserState state, int type, int n, long[] result) { final int alignHint = readAtomicAlignHint(n); final int memoryIndex = readMemoryIndexFromAlignHint(alignHint); final long memoryOffset = readBaseMemoryOffset(); @@ -2455,7 +2483,7 @@ private void atomicReadModifyWrite(ParserState state, byte type, int n, long[] r result[1] = memoryOffset; } - private void atomicCompareExchange(ParserState state, byte type, int n, long[] result) { + private void atomicCompareExchange(ParserState state, int type, int n, long[] result) { final int alignHint = readAtomicAlignHint(n); final int memoryIndex = readMemoryIndexFromAlignHint(alignHint); final long memoryOffset = readBaseMemoryOffset(); @@ -2486,7 +2514,7 @@ private void atomicNotify(ParserState state, long[] result) { result[1] = memoryOffset; } - private void atomicWait(ParserState state, byte type, int n, long[] result) { + private void atomicWait(ParserState state, int type, int n, long[] result) { final int alignHint = readAtomicAlignHint(n); final int memoryIndex = readMemoryIndexFromAlignHint(alignHint); final long memoryOffset = readBaseMemoryOffset(); @@ -2514,7 +2542,7 @@ private ExceptionHandler[] readExceptionHandlers(ParserState state) { final int label = readUnsignedInt32(); assertUnsignedIntLess(label, state.controlStackSize(), Failure.INVALID_CATCH_CLAUSE_LABEL); final int typeIndex = module.tagTypeIndex(tag); - final byte[] paramTypes = module.typeAt(typeIndex).paramTypes(); + final int[] paramTypes = module.typeAt(typeIndex).paramTypes(); handlers[i] = state.enterCatchClause(opcode, tag, label); state.pushAll(paramTypes); } @@ -2523,7 +2551,7 @@ private ExceptionHandler[] readExceptionHandlers(ParserState state) { final int label = readUnsignedInt32(); assertUnsignedIntLess(label, state.controlStackSize(), Failure.INVALID_CATCH_CLAUSE_LABEL); final int typeIndex = module.tagTypeIndex(tag); - final byte[] paramTypes = module.typeAt(typeIndex).paramTypes(); + final int[] paramTypes = module.typeAt(typeIndex).paramTypes(); handlers[i] = state.enterCatchClause(opcode, tag, label); state.pushAll(paramTypes); state.push(EXNREF_TYPE); @@ -2565,7 +2593,7 @@ private Pair readLongOffsetExpression() { } } - private Pair readConstantExpression(byte resultType, boolean onlyImportedGlobals) { + private Pair readConstantExpression(int resultType, boolean onlyImportedGlobals) { // Read the constant expression. // https://webassembly.github.io/spec/core/valid/instructions.html#constant-expressions final RuntimeBytecodeGen bytecode = new RuntimeBytecodeGen(); @@ -2574,7 +2602,7 @@ private Pair readConstantExpression(byte resultType, boolean onl final List stack = new ArrayList<>(); boolean calculable = true; - state.enterFunction(new byte[]{resultType}); + state.enterFunction(new int[]{resultType}); int opcode; while ((opcode = read1() & 0xFF) != Instructions.END) { switch (opcode) { @@ -2618,7 +2646,7 @@ private Pair readConstantExpression(byte resultType, boolean onl } case Instructions.REF_NULL: checkBulkMemoryAndRefTypesSupport(opcode); - final byte type = readRefType(exceptions); + final int type = readRefType(exceptions); state.push(type); state.addInstruction(Bytecode.REF_NULL); if (calculable) { @@ -2741,7 +2769,7 @@ private void checkElemKind() { } } - private long[] readElemExpressions(byte elemType) { + private long[] readElemExpressions(int elemType) { final int expressionCount = readLength(); final long[] functionIndices = new long[expressionCount]; for (int index = 0; index != expressionCount; index++) { @@ -2765,7 +2793,7 @@ private long[] readElemExpressions(byte elemType) { throw WasmException.format(Failure.ILLEGAL_OPCODE, "Illegal opcode for constant expression: 0x%02X", opcode); } case Instructions.REF_NULL: - final byte type = readRefType(exceptions); + final int type = readRefType(exceptions); if (bulkMemoryAndRefTypes && type != elemType) { fail(Failure.TYPE_MISMATCH, "Invalid ref.null type: 0x%02X", type); } @@ -2782,8 +2810,8 @@ private long[] readElemExpressions(byte elemType) { case Instructions.GLOBAL_GET: final int globalIndex = readGlobalIndex(); assertIntEqual(module.globalMutability(globalIndex), GlobalModifier.CONSTANT, Failure.CONSTANT_EXPRESSION_REQUIRED); - final byte valueType = module.globalValueType(globalIndex); - assertByteEqual(valueType, elemType, Failure.TYPE_MISMATCH); + final int valueType = module.globalValueType(globalIndex); + assertIntEqual(valueType, elemType, Failure.TYPE_MISMATCH); functionIndices[index] = ((long) I32_TYPE << 32) | globalIndex; break; case Instructions.VECTOR: @@ -2813,7 +2841,7 @@ private void readElementSection(RuntimeBytecodeGen bytecode) { final byte[] currentOffsetBytecode; final long[] elements; final int tableIndex; - final byte elemType; + final int elemType; if (bulkMemoryAndRefTypes) { final int sectionType = readUnsignedInt32(); mode = sectionType & 0b001; @@ -2971,7 +2999,7 @@ private void readGlobalSection() { final int startingGlobalIndex = module.symbolTable().numGlobals(); for (int globalIndex = startingGlobalIndex; globalIndex != startingGlobalIndex + globalCount; globalIndex++) { assertTrue(!isEOF(), Failure.LENGTH_OUT_OF_BOUNDS); - final byte type = readValueType(bulkMemoryAndRefTypes, simd, exceptions); + final int type = readValueType(bulkMemoryAndRefTypes, simd, exceptions); // 0x00 means const, 0x01 means var final byte mutability = readMutability(); // Global initialization expressions must be constant expressions: @@ -3086,29 +3114,123 @@ private void readDataSection(RuntimeBytecodeGen bytecode) { private void readFunctionType() { int paramCount = readLength(); - long resultCountAndValue = peekUnsignedInt32AndLength(data, offset + paramCount); - int resultCount = value(resultCountAndValue); - resultCount = (resultCount == 0x40) ? 0 : resultCount; - module.limits().checkParamCount(paramCount); + int[] paramTypes = new int[paramCount]; + for (int paramIdx = 0; paramIdx < paramCount; paramIdx++) { + paramTypes[paramIdx] = readValueType(bulkMemoryAndRefTypes, simd, exceptions); + } + + int resultCount = readLength(); module.limits().checkResultCount(resultCount, multiValue); - int idx = module.symbolTable().allocateFunctionType(paramCount, resultCount, multiValue); - readParameterList(idx, paramCount); - offset += length(resultCountAndValue); - readResultList(idx, resultCount); + int[] resultTypes = new int[resultCount]; + for (int resultIdx = 0; resultIdx < resultCount; resultIdx++) { + resultTypes[resultIdx] = readValueType(bulkMemoryAndRefTypes, simd, exceptions); + } + + int funcTypeIdx = module.symbolTable().allocateFunctionType(paramCount, resultCount, multiValue); + for (int paramIdx = 0; paramIdx < paramCount; paramIdx++) { + module.symbolTable().registerFunctionTypeParameterType(funcTypeIdx, paramIdx, paramTypes[paramIdx]); + } + for (int resultIdx = 0; resultIdx < resultCount; resultIdx++) { + module.symbolTable().registerFunctionTypeResultType(funcTypeIdx, resultIdx, resultTypes[resultIdx]); + } } - private void readParameterList(int funcTypeIdx, int paramCount) { - for (int paramIdx = 0; paramIdx != paramCount; ++paramIdx) { - byte type = readValueType(bulkMemoryAndRefTypes, simd, exceptions); - module.symbolTable().registerFunctionTypeParameterType(funcTypeIdx, paramIdx, type); + protected int readValueType(boolean allowRefTypes, boolean allowVecType, boolean allowExnType) { + final int type = readSignedInt32(); + switch (type) { + case I32_TYPE, I64_TYPE, F32_TYPE, F64_TYPE -> { + return type; + } + case V128_TYPE -> { + Assert.assertTrue(allowVecType, Failure.MALFORMED_VALUE_TYPE); + return type; + } + case FUNCREF_TYPE, EXTERNREF_TYPE -> { + Assert.assertTrue(allowRefTypes, Failure.MALFORMED_VALUE_TYPE); + return type; + } + case EXNREF_TYPE -> { + Assert.assertTrue(allowExnType, Failure.MALFORMED_VALUE_TYPE); + return type; + } + case REF_NULL_TYPE_HEADER, REF_TYPE_HEADER -> { + boolean nullable = type == REF_NULL_TYPE_HEADER; + int heapType = readSignedInt32(); + return switch (heapType) { + case FUNC_HEAPTYPE, EXTERN_HEAPTYPE -> WasmType.withNullable(nullable, heapType); + case EXN_HEAPTYPE -> { + Assert.assertTrue(allowExnType, Failure.MALFORMED_VALUE_TYPE); + yield WasmType.withNullable(nullable, heapType); + } + default -> { + if (heapType < 0 || heapType > WasmType.MAX_TYPE_INDEX) { + throw fail(Failure.MALFORMED_REFERENCE_TYPE, "Invalid heap type"); + } + yield WasmType.withNullable(nullable, heapType); + } + }; + } + default -> throw Assert.fail(Failure.MALFORMED_VALUE_TYPE, "Invalid value type: 0x%02X", type); } } - private void readResultList(int funcTypeIdx, int resultCount) { - for (int resultIdx = 0; resultIdx != resultCount; resultIdx++) { - byte type = readValueType(bulkMemoryAndRefTypes, simd, exceptions); - module.symbolTable().registerFunctionTypeResultType(funcTypeIdx, resultIdx, type); + /** + * Reads the block type at the current location. The result is provided as two values. The first + * is the actual value of the block type. The second is an indicator if it is a single result + * type or a multi-value result. + * + * @param result The array used for returning the result. + * + */ + protected void readBlockType(int[] result, boolean allowRefTypes, boolean allowVecType, boolean allowExnType) { + int type = readSignedInt32(); + switch (type) { + case VOID_BLOCK_TYPE -> { + result[1] = BLOCK_TYPE_VOID; + } + case I32_TYPE, I64_TYPE, F32_TYPE, F64_TYPE -> { + result[0] = type; + result[1] = BLOCK_TYPE_VALTYPE; + } + case V128_TYPE -> { + Assert.assertTrue(allowVecType, Failure.MALFORMED_VALUE_TYPE); + result[0] = type; + result[1] = BLOCK_TYPE_VALTYPE; + } + case FUNCREF_TYPE, EXTERNREF_TYPE -> { + Assert.assertTrue(allowRefTypes, Failure.MALFORMED_VALUE_TYPE); + result[0] = type; + result[1] = BLOCK_TYPE_VALTYPE; + } + case EXNREF_TYPE -> { + Assert.assertTrue(allowExnType, Failure.MALFORMED_VALUE_TYPE); + result[0] = type; + result[1] = BLOCK_TYPE_VALTYPE; + } + case REF_NULL_TYPE_HEADER, REF_TYPE_HEADER -> { + boolean nullable = type == REF_NULL_TYPE_HEADER; + int heapType = readSignedInt32(); + result[0] = switch (heapType) { + case FUNC_HEAPTYPE, EXTERN_HEAPTYPE -> WasmType.withNullable(nullable, heapType); + case EXN_HEAPTYPE -> { + Assert.assertTrue(allowExnType, Failure.MALFORMED_VALUE_TYPE); + yield WasmType.withNullable(nullable, heapType); + } + default -> { + if (heapType < 0 || heapType > WasmType.MAX_TYPE_INDEX) { + throw fail(Failure.MALFORMED_REFERENCE_TYPE, "Invalid heap type"); + } + yield WasmType.withNullable(nullable, heapType); + } + }; + result[1] = BLOCK_TYPE_VALTYPE; + } + default -> { + result[0] = type; + Assert.assertIntGreaterOrEqual(result[0], 0, Failure.MALFORMED_VALUE_TYPE); + result[1] = BLOCK_TYPE_TYPE_INDEX; + } } } @@ -3198,20 +3320,35 @@ private byte readImportType() { return read1(); } - private byte readRefType(boolean allowExnType) { - final byte refType = read1(); + private int readRefType(boolean allowExnType) { + final int refType = readSignedInt32(); switch (refType) { - case FUNCREF_TYPE: - case EXTERNREF_TYPE: - break; - case EXNREF_TYPE: + case FUNCREF_TYPE, EXTERNREF_TYPE -> { + return refType; + } + case EXNREF_TYPE -> { assertTrue(allowExnType, Failure.MALFORMED_REFERENCE_TYPE); - break; - default: - fail(Failure.MALFORMED_REFERENCE_TYPE, "Unexpected reference type"); - break; + return refType; + } + case REF_NULL_TYPE_HEADER, REF_TYPE_HEADER -> { + boolean nullable = refType == REF_NULL_TYPE_HEADER; + int heapType = readSignedInt32(); + return switch (heapType) { + case FUNC_HEAPTYPE, EXTERN_HEAPTYPE -> WasmType.withNullable(nullable, heapType); + case EXN_HEAPTYPE -> { + assertTrue(allowExnType, Failure.MALFORMED_REFERENCE_TYPE); + yield WasmType.withNullable(nullable, heapType); + } + default -> { + if (heapType < 0 || heapType > WasmType.MAX_TYPE_INDEX) { + throw fail(Failure.MALFORMED_REFERENCE_TYPE, "Unexpected reference type"); + } + yield WasmType.withNullable(nullable, heapType); + } + }; + } + default -> throw fail(Failure.MALFORMED_REFERENCE_TYPE, "Unexpected reference type"); } - return refType; } private void readTableLimits(int[] out) { diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryStreamParser.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryStreamParser.java index 5ba793ba08e2..4ca161b0c6ea 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryStreamParser.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryStreamParser.java @@ -57,8 +57,9 @@ import com.oracle.truffle.api.nodes.ExplodeLoop; public abstract class BinaryStreamParser { - protected static final int SINGLE_RESULT_VALUE = 0; - protected static final int MULTI_RESULT_VALUE = 1; + protected static final int BLOCK_TYPE_VOID = 0; + protected static final int BLOCK_TYPE_VALTYPE = 1; + protected static final int BLOCK_TYPE_TYPE_INDEX = 2; private static final VarHandle I16LE = MethodHandles.byteArrayViewVarHandle(short[].class, ByteOrder.LITTLE_ENDIAN); private static final VarHandle I32LE = MethodHandles.byteArrayViewVarHandle(int[].class, ByteOrder.LITTLE_ENDIAN); @@ -294,84 +295,6 @@ protected int offset() { return offset; } - /** - * Reads the block type at the current location. The result is provided as two values. The first - * is the actual value of the block type. The second is an indicator if it is a single result - * type or a multi-value result. - * - * @param result The array used for returning the result. - * - */ - protected void readBlockType(int[] result, boolean allowRefTypes, boolean allowVecType, boolean allowExnType) { - byte type = peek1(data, offset); - switch (type) { - case WasmType.VOID_TYPE: - case WasmType.I32_TYPE: - case WasmType.I64_TYPE: - case WasmType.F32_TYPE: - case WasmType.F64_TYPE: - offset++; - result[0] = type; - result[1] = SINGLE_RESULT_VALUE; - break; - case WasmType.V128_TYPE: - Assert.assertTrue(allowVecType, Failure.MALFORMED_VALUE_TYPE); - offset++; - result[0] = type; - result[1] = SINGLE_RESULT_VALUE; - break; - case WasmType.FUNCREF_TYPE: - case WasmType.EXTERNREF_TYPE: - Assert.assertTrue(allowRefTypes, Failure.MALFORMED_VALUE_TYPE); - offset++; - result[0] = type; - result[1] = SINGLE_RESULT_VALUE; - break; - case WasmType.EXNREF_TYPE: - Assert.assertTrue(allowExnType, Failure.MALFORMED_VALUE_TYPE); - offset++; - result[0] = type; - result[1] = SINGLE_RESULT_VALUE; - break; - default: - long valueAndLength = peekSignedInt32AndLength(data, offset); - result[0] = value(valueAndLength); - Assert.assertIntGreaterOrEqual(result[0], 0, Failure.UNSPECIFIED_MALFORMED); - result[1] = MULTI_RESULT_VALUE; - offset += length(valueAndLength); - } - } - - protected static byte peekValueType(byte[] data, int offset, boolean allowRefTypes, boolean allowVecType, boolean allowExnType) { - byte b = peek1(data, offset); - switch (b) { - case WasmType.I32_TYPE: - case WasmType.I64_TYPE: - case WasmType.F32_TYPE: - case WasmType.F64_TYPE: - break; - case WasmType.V128_TYPE: - Assert.assertTrue(allowVecType, Failure.MALFORMED_VALUE_TYPE); - break; - case WasmType.FUNCREF_TYPE: - case WasmType.EXTERNREF_TYPE: - Assert.assertTrue(allowRefTypes, Failure.MALFORMED_VALUE_TYPE); - break; - case WasmType.EXNREF_TYPE: - Assert.assertTrue(allowExnType, Failure.MALFORMED_VALUE_TYPE); - break; - default: - Assert.fail(Failure.MALFORMED_VALUE_TYPE, "Invalid value type: 0x%02X", b); - } - return b; - } - - protected byte readValueType(boolean allowRefTypes, boolean allowVecType, boolean allowExnType) { - byte b = peekValueType(data, offset, allowRefTypes, allowVecType, allowExnType); - offset++; - return b; - } - @ExplodeLoop(kind = FULL_EXPLODE_UNTIL_RETURN) public static byte peekLeb128Length(byte[] data, int initialOffset) { int currentOffset = initialOffset; diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/GlobalRegistry.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/GlobalRegistry.java index 47d8a54b2cfd..8bfc6fb9b07a 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/GlobalRegistry.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/GlobalRegistry.java @@ -46,7 +46,6 @@ import org.graalvm.wasm.api.Vector128; import org.graalvm.wasm.globals.WasmGlobal; -import com.oracle.truffle.api.CompilerDirectives; import com.oracle.truffle.api.CompilerDirectives.CompilationFinal; /** @@ -117,15 +116,17 @@ private Object loadAsObject(int address) { return objectGlobals[address]; } - public void store(byte globalValueType, int address, Object value) { + public void store(int globalValueType, int address, Object value) { switch (globalValueType) { case WasmType.I32_TYPE -> storeInt(address, (int) value); case WasmType.I64_TYPE -> storeLong(address, (long) value); case WasmType.F32_TYPE -> storeFloat(address, (float) value); case WasmType.F64_TYPE -> storeDouble(address, (double) value); case WasmType.V128_TYPE -> storeVector128(address, (Vector128) value); - case WasmType.FUNCREF_TYPE, WasmType.EXTERNREF_TYPE, WasmType.EXNREF_TYPE -> storeReference(address, value); - default -> throw CompilerDirectives.shouldNotReachHere(); + default -> { + assert WasmType.isReferenceType(globalValueType); + storeReference(address, value); + } } } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java index 8479286032aa..4c9049a51879 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java @@ -40,8 +40,8 @@ */ package org.graalvm.wasm; -import static org.graalvm.wasm.Assert.assertByteEqual; import static org.graalvm.wasm.Assert.assertFunctionTypeEquals; +import static org.graalvm.wasm.Assert.assertIntEqual; import static org.graalvm.wasm.Assert.assertTrue; import static org.graalvm.wasm.Assert.assertUnsignedIntGreaterOrEqual; import static org.graalvm.wasm.Assert.assertUnsignedIntLess; @@ -55,10 +55,8 @@ import static org.graalvm.wasm.BinaryStreamParser.rawPeekI8; import static org.graalvm.wasm.BinaryStreamParser.rawPeekU8; import static org.graalvm.wasm.Linker.ResolutionDag.NO_RESOLVE_ACTION; -import static org.graalvm.wasm.WasmType.EXTERNREF_TYPE; import static org.graalvm.wasm.WasmType.F32_TYPE; import static org.graalvm.wasm.WasmType.F64_TYPE; -import static org.graalvm.wasm.WasmType.FUNCREF_TYPE; import static org.graalvm.wasm.WasmType.I32_TYPE; import static org.graalvm.wasm.WasmType.I64_TYPE; import static org.graalvm.wasm.WasmType.V128_TYPE; @@ -89,7 +87,6 @@ import org.graalvm.wasm.Linker.ResolutionDag.InitializeGlobalSym; import org.graalvm.wasm.Linker.ResolutionDag.Resolver; import org.graalvm.wasm.Linker.ResolutionDag.Sym; -import org.graalvm.wasm.SymbolTable.FunctionType; import org.graalvm.wasm.api.ExecuteHostFunctionNode; import org.graalvm.wasm.api.ValueType; import org.graalvm.wasm.api.Vector128; @@ -196,7 +193,12 @@ private void tryLinkOutsidePartialEvaluation(WasmInstance entryPointInstance, Im ArrayList failures = new ArrayList<>(); final int maxStartFunctionIndex = runLinkActions(store, instances, importValues, failures); linkTopologically(store, failures, maxStartFunctionIndex); - assignTypeEquivalenceClasses(store); + for (WasmInstance instance : instances.values()) { + WasmModule module = instance.module(); + if (instance.isLinkInProgress() && !module.isParsed()) { + module.setParsed(); + } + } resolutionDag = null; runStartFunctions(instances, failures); checkFailures(failures); @@ -243,39 +245,6 @@ private void linkTopologically(WasmStore store, ArrayList failures, i } } - private static void assignTypeEquivalenceClasses(WasmStore store) { - final Map instances = store.moduleInstances(); - for (WasmInstance instance : instances.values()) { - WasmModule module = instance.module(); - if (instance.isLinkInProgress() && !module.isParsed()) { - assignTypeEquivalenceClasses(module, store.language()); - } - } - } - - private static void assignTypeEquivalenceClasses(WasmModule module, WasmLanguage language) { - var lock = module.getLock(); - lock.lock(); - try { - if (module.isParsed()) { - return; - } - final SymbolTable symtab = module.symbolTable(); - for (int index = 0; index < symtab.typeCount(); index++) { - FunctionType type = symtab.typeAt(index); - int equivalenceClass = language.equivalenceClassFor(type); - symtab.setEquivalenceClass(index, equivalenceClass); - } - for (int index = 0; index < symtab.numFunctions(); index++) { - final WasmFunction function = symtab.function(index); - function.setTypeEquivalenceClass(symtab.equivalenceClass(function.typeIndex())); - } - module.setParsed(); - } finally { - lock.unlock(); - } - } - private static void runStartFunctions(Map instances, ArrayList failures) { List instanceList = new ArrayList<>(instances.values()); instanceList.sort(Comparator.comparingInt(RuntimeState::startFunctionIndex)); @@ -332,7 +301,7 @@ private static void checkFailures(ArrayList failures) { } } - void resolveGlobalImport(WasmStore store, WasmInstance instance, ImportDescriptor importDescriptor, int globalIndex, byte valueType, byte mutability, + void resolveGlobalImport(WasmStore store, WasmInstance instance, ImportDescriptor importDescriptor, int globalIndex, int valueType, byte mutability, ImportValueSupplier imports) { instance.globals().setInitialized(globalIndex, false); final String importedGlobalName = importDescriptor.memberName(); @@ -340,10 +309,10 @@ void resolveGlobalImport(WasmStore store, WasmInstance instance, ImportDescripto final Runnable resolveAction = () -> { assert instance.module().globalImported(globalIndex) && globalIndex == importDescriptor.targetIndex() : importDescriptor; WasmGlobal externalGlobal = lookupImportObject(instance, importDescriptor, imports, WasmGlobal.class); - final byte exportedValueType; + final int exportedValueType; final byte exportedMutability; if (externalGlobal != null) { - exportedValueType = externalGlobal.getValueType().byteValue(); + exportedValueType = externalGlobal.getValueType().value(); exportedMutability = externalGlobal.getMutability(); } else { final WasmInstance importedInstance = store.lookupModuleInstance(importedModuleName); @@ -393,7 +362,7 @@ private static void initializeGlobal(WasmInstance instance, int globalIndex, Obj assert !instance.globals().isInitialized(globalIndex) : globalIndex; SymbolTable symbolTable = instance.symbolTable(); if (symbolTable.globalExternal(globalIndex)) { - var global = new WasmGlobal(ValueType.fromByteValue(symbolTable.globalValueType(globalIndex)), symbolTable.isGlobalMutable(globalIndex), initValue); + var global = new WasmGlobal(ValueType.fromValue(symbolTable.globalValueType(globalIndex)), symbolTable.isGlobalMutable(globalIndex), initValue); instance.setExternalGlobal(globalIndex, global); } else { instance.globals().store(symbolTable.globalValueType(globalIndex), symbolTable.globalAddress(globalIndex), initValue); @@ -586,7 +555,7 @@ void resolveTagExport(WasmInstance instance, int tagIndex, String exportedTagNam private static Object lookupGlobal(WasmInstance instance, int index) { final SymbolTable symbolTable = instance.symbolTable(); - final byte type = symbolTable.globalValueType(index); + final int type = symbolTable.globalValueType(index); final int globalAddress = symbolTable.globalAddress(index); final GlobalRegistry globals = instance.globals(); if (!globals.isInitialized(index)) { @@ -598,8 +567,10 @@ private static Object lookupGlobal(WasmInstance instance, int index) { case I64_TYPE -> globals.loadAsLong(globalAddress); case F64_TYPE -> globals.loadAsDouble(globalAddress); case V128_TYPE -> globals.loadAsVector128(globalAddress); - case FUNCREF_TYPE, EXTERNREF_TYPE -> globals.loadAsReference(globalAddress); - default -> throw WasmException.create(Failure.UNSPECIFIED_TRAP, "Global variable cannot have the void type."); + default -> { + assert WasmType.isReferenceType(type); + yield globals.loadAsReference(globalAddress); + } }; } @@ -820,7 +791,7 @@ void resolvePassiveDataSegment(WasmStore store, WasmInstance instance, int dataS resolutionDag.resolveLater(new DataSym(instance.name(), dataSegmentId), dependencies.toArray(new Sym[0]), resolveAction); } - void resolveTableImport(WasmStore store, WasmInstance instance, ImportDescriptor importDescriptor, int tableIndex, int declaredMinSize, int declaredMaxSize, byte elemType, + void resolveTableImport(WasmStore store, WasmInstance instance, ImportDescriptor importDescriptor, int tableIndex, int declaredMinSize, int declaredMaxSize, int elemType, ImportValueSupplier imports) { final Runnable resolveAction = () -> { WasmTable externalTable = lookupImportObject(instance, importDescriptor, imports, WasmTable.class); @@ -856,7 +827,7 @@ void resolveTableImport(WasmStore store, WasmInstance instance, ImportDescriptor // MAX_TABLE_DECLARATION_SIZE, so this condition will pass. assertUnsignedIntLessOrEqual(declaredMinSize, importedTable.minSize(), Failure.INCOMPATIBLE_IMPORT_TYPE); assertUnsignedIntGreaterOrEqual(declaredMaxSize, importedTable.declaredMaxSize(), Failure.INCOMPATIBLE_IMPORT_TYPE); - assertByteEqual(elemType, importedTable.elemType(), Failure.INCOMPATIBLE_IMPORT_TYPE); + assertIntEqual(elemType, importedTable.elemType(), Failure.INCOMPATIBLE_IMPORT_TYPE); instance.setTableAddress(tableIndex, tableAddress); }; Sym[] dependencies = new Sym[]{new ExportTableSym(importDescriptor.moduleName(), importDescriptor.memberName())}; diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/ModuleLimits.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/ModuleLimits.java index 4a2e739e6084..8b32814a25ec 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/ModuleLimits.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/ModuleLimits.java @@ -80,7 +80,7 @@ public ModuleLimits(int moduleSizeLimit, int typeCountLimit, int functionCountLi int dataSegmentCountLimit, int elementSegmentCountLimit, int functionSizeLimit, int paramCountLimit, int resultCountLimit, int localCountLimit, int tableInstanceSizeLimit, int memoryInstanceSizeLimit, long memory64InstanceSizeLimit, int tagCountLimit) { this.moduleSizeLimit = minUnsigned(moduleSizeLimit, Integer.MAX_VALUE); - this.typeCountLimit = minUnsigned(typeCountLimit, Integer.MAX_VALUE); + this.typeCountLimit = minUnsigned(typeCountLimit, WasmType.MAX_TYPE_INDEX); this.functionCountLimit = minUnsigned(functionCountLimit, Integer.MAX_VALUE); this.tableCountLimit = minUnsigned(tableCountLimit, Integer.MAX_VALUE); this.multiMemoryCountLimit = minUnsigned(memoryCountLimit, Integer.MAX_VALUE); @@ -109,7 +109,7 @@ private static long minUnsigned(long a, long b) { static final ModuleLimits DEFAULTS = new ModuleLimits( Integer.MAX_VALUE, - Integer.MAX_VALUE, + WasmType.MAX_TYPE_INDEX, Integer.MAX_VALUE, Integer.MAX_VALUE, Integer.MAX_VALUE, diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java index 1119e5483b42..e2a6a9bbf3b6 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java @@ -40,7 +40,6 @@ */ package org.graalvm.wasm; -import static org.graalvm.wasm.Assert.assertByteEqual; import static org.graalvm.wasm.Assert.assertIntEqual; import static org.graalvm.wasm.Assert.assertTrue; import static org.graalvm.wasm.Assert.assertUnsignedIntLess; @@ -84,32 +83,161 @@ public abstract class SymbolTable { private static final byte GLOBAL_FUNCTION_INITIALIZER_BIT = 0x20; public static final int UNINITIALIZED_ADDRESS = Integer.MIN_VALUE; - private static final int NO_EQUIVALENCE_CLASS = 0; - static final int FIRST_EQUIVALENCE_CLASS = NO_EQUIVALENCE_CLASS + 1; + + public abstract static sealed class ClosedValueType { + public abstract boolean matches(ClosedValueType valueSubType); + } + + public abstract static sealed class ClosedHeapType { + public abstract boolean matches(ClosedHeapType heapSubType); + } + + public static final class NumberType extends ClosedValueType { + public static final NumberType I32 = new NumberType(WasmType.I32_TYPE); + public static final NumberType I64 = new NumberType(WasmType.I64_TYPE); + public static final NumberType F32 = new NumberType(WasmType.F32_TYPE); + public static final NumberType F64 = new NumberType(WasmType.F64_TYPE); + + private final int value; + + private NumberType(int value) { + this.value = value; + } + + @Override + public boolean matches(ClosedValueType valueSubType) { + return this == valueSubType; + } + } + + public static final class VectorType extends ClosedValueType { + public static final VectorType V128 = new VectorType(WasmType.V128_TYPE); + + private final int value; + + private VectorType(int value) { + this.value = value; + } + + @Override + public boolean matches(ClosedValueType valueSubType) { + return this == valueSubType; + } + } + + public static final class ClosedReferenceType extends ClosedValueType { + public static final ClosedReferenceType FUNCREF = new ClosedReferenceType(true, AbstractHeapType.FUNC); + public static final ClosedReferenceType NONNULL_FUNCREF = new ClosedReferenceType(false, AbstractHeapType.FUNC); + public static final ClosedReferenceType EXTERNREF = new ClosedReferenceType(true, AbstractHeapType.EXTERN); + public static final ClosedReferenceType NONNULL_EXTERNREF = new ClosedReferenceType(false, AbstractHeapType.EXTERN); + + private final boolean nullable; + private final ClosedHeapType closedHeapType; + + public ClosedReferenceType(boolean nullable, ClosedHeapType closedHeapType) { + this.nullable = nullable; + this.closedHeapType = closedHeapType; + } + + @Override + public boolean matches(ClosedValueType valueSubType) { + return valueSubType instanceof ClosedReferenceType referenceSubType && (!referenceSubType.nullable || this.nullable) && this.closedHeapType.matches(referenceSubType.closedHeapType); + } + } + + public static final class AbstractHeapType extends ClosedHeapType { + public static final AbstractHeapType FUNC = new AbstractHeapType(WasmType.FUNC_HEAPTYPE); + public static final AbstractHeapType EXTERN = new AbstractHeapType(WasmType.EXTERN_HEAPTYPE); + + private final int value; + + private AbstractHeapType(int value) { + this.value = value; + } + + @Override + public boolean matches(ClosedHeapType heapSubType) { + return switch (this.value) { + case WasmType.FUNC_HEAPTYPE -> this == FUNC || heapSubType instanceof ClosedFunctionType; + case WasmType.EXTERN_HEAPTYPE -> this == EXTERN; + default -> throw CompilerDirectives.shouldNotReachHere(); + }; + } + } + + public static final class ClosedFunctionType extends ClosedHeapType { + @CompilationFinal(dimensions = 1) private final ClosedValueType[] paramTypes; + @CompilationFinal(dimensions = 1) private final ClosedValueType[] resultTypes; + + ClosedFunctionType(ClosedValueType[] paramTypes, ClosedValueType[] resultTypes) { + this.paramTypes = paramTypes; + this.resultTypes = resultTypes; + } + + @Override + public boolean matches(ClosedHeapType heapSubType) { + if (!(heapSubType instanceof ClosedFunctionType functionSubType)) { + return false; + } + if (this.paramTypes.length != functionSubType.paramTypes.length) { + return false; + } + for (int i = 0; i < this.paramTypes.length; i++) { + if (!functionSubType.paramTypes[i].matches(this.paramTypes[i])) { + return false; + } + } + if (this.resultTypes.length != functionSubType.resultTypes.length) { + return false; + } + for (int i = 0; i < this.resultTypes.length; i++) { + if (!this.resultTypes[i].matches(functionSubType.resultTypes[i])) { + return false; + } + } + return true; + } + } public static final class FunctionType { - @CompilationFinal(dimensions = 1) private final byte[] paramTypes; - @CompilationFinal(dimensions = 1) private final byte[] resultTypes; + @CompilationFinal(dimensions = 1) private final int[] paramTypes; + @CompilationFinal(dimensions = 1) private final int[] resultTypes; + private final SymbolTable symbolTable; private final int hashCode; - FunctionType(byte[] paramTypes, byte[] resultTypes) { + FunctionType(int[] paramTypes, int[] resultTypes, SymbolTable symbolTable) { this.paramTypes = paramTypes; this.resultTypes = resultTypes; - this.hashCode = Arrays.hashCode(paramTypes) ^ Arrays.hashCode(resultTypes); + this.symbolTable = symbolTable; + this.hashCode = Arrays.hashCode(paramTypes) ^ Arrays.hashCode(resultTypes) ^ Arrays.hashCode(symbolTable.typeData); } - public static FunctionType create(byte[] paramTypes, byte[] resultTypes) { - return new FunctionType(paramTypes, resultTypes); + public static FunctionType createClosed(int[] paramTypes, int[] resultTypes) { + assert allClosed(paramTypes) && allClosed(resultTypes); + return new FunctionType(paramTypes, resultTypes, null); } - public byte[] paramTypes() { + private static boolean allClosed(int[] types) { + for (int type : types) { + if (WasmType.isConcreteReferenceType(type)) { + return false; + } + } + return true; + } + + public int[] paramTypes() { return paramTypes; } - public byte[] resultTypes() { + public int[] resultTypes() { return resultTypes; } + public SymbolTable symbolTable() { + return symbolTable; + } + @Override public int hashCode() { return hashCode; @@ -162,7 +290,7 @@ public String toString() { * might have a lower internal max allowed size in practice. * @param elemType The element type of the table. */ - public record TableInfo(int initialSize, int maximumSize, byte elemType) { + public record TableInfo(int initialSize, int maximumSize, int elemType) { } /** @@ -214,16 +342,6 @@ public record TagInfo(byte attribute, int typeIndex) { */ @CompilationFinal(dimensions = 1) private int[] typeOffsets; - /** - * Stores the type equivalence class. - *

- * Since multiple types have the same shape, each type is mapped to an equivalence class, so - * that two types can be quickly compared. - *

- * The equivalence classes are computed globally for all the modules, during linking. - */ - @CompilationFinal(dimensions = 1) private int[] typeEquivalenceClasses; - @CompilationFinal private int typeDataSize; @CompilationFinal private int typeCount; @@ -269,14 +387,18 @@ public record TagInfo(byte attribute, int typeIndex) { @CompilationFinal private int startFunctionIndex; /** - * A global type is the value type of the global, followed by its mutability. This is encoded as - * two bytes -- the lowest (0th) byte is the value type. The 1st byte is organized like this: + * Value types of globals. + */ + @CompilationFinal(dimensions = 1) private int[] globalTypes; + + /** + * Mutability flags of globals. These are encoded like this: *

* * | . | . | . | functionOrIndex flag | reference flag | initialized flag | exported flag | mutable flag | * */ - @CompilationFinal(dimensions = 1) private byte[] globalTypes; + @CompilationFinal(dimensions = 1) private byte[] globalFlags; /** * The values or indices used for initializing globals. @@ -426,7 +548,6 @@ public record TagInfo(byte attribute, int typeIndex) { CompilerAsserts.neverPartOfCompilation(); this.typeData = new int[INITIAL_DATA_SIZE]; this.typeOffsets = new int[INITIAL_TYPE_SIZE]; - this.typeEquivalenceClasses = new int[INITIAL_TYPE_SIZE]; this.typeDataSize = 0; this.typeCount = 0; this.importedSymbols = new ArrayList<>(); @@ -438,7 +559,8 @@ public record TagInfo(byte attribute, int typeIndex) { this.exportedFunctions = EconomicMap.create(); this.exportedFunctionsByIndex = EconomicMap.create(); this.startFunctionIndex = -1; - this.globalTypes = new byte[2 * INITIAL_GLOBALS_SIZE]; + this.globalTypes = new int[INITIAL_GLOBALS_SIZE]; + this.globalFlags = new byte[INITIAL_GLOBALS_SIZE]; this.globalInitializers = new Object[INITIAL_GLOBALS_SIZE]; this.globalInitializersBytecode = new byte[INITIAL_GLOBALS_BYTECODE_SIZE][]; this.importedGlobals = EconomicMap.create(); @@ -509,9 +631,9 @@ private void ensureTypeDataCapacity(int index) { } /** - * Ensure that the {@link #typeOffsets} and {@link #typeEquivalenceClasses} arrays have enough - * space to store the data for the type at {@code index}. If there is not enough space, then a - * reallocation of the array takes place, doubling its capacity. + * Ensure that the {@link #typeOffsets} array has enough space to store the data for the type at + * {@code index}. If there is not enough space, then a reallocation of the array takes place, + * doubling its capacity. *

* No synchronisation is required for this method, as it is only called during parsing, which is * carried out by a single thread. @@ -520,7 +642,6 @@ private void ensureTypeCapacity(int index) { if (typeOffsets.length <= index) { int newLength = Math.max(Integer.highestOneBit(index) << 1, 2 * typeOffsets.length); typeOffsets = reallocate(typeOffsets, typeCount, newLength); - typeEquivalenceClasses = reallocate(typeEquivalenceClasses, typeCount, newLength); } } @@ -542,7 +663,7 @@ int allocateFunctionType(int paramCount, int resultCount, boolean isMultiValue) return typeIdx; } - public int allocateFunctionType(byte[] paramTypes, byte[] resultTypes, boolean isMultiValue) { + public int allocateFunctionType(int[] paramTypes, int[] resultTypes, boolean isMultiValue) { checkNotParsed(); final int typeIdx = allocateFunctionType(paramTypes.length, resultTypes.length, isMultiValue); for (int i = 0; i < paramTypes.length; i++) { @@ -554,30 +675,18 @@ public int allocateFunctionType(byte[] paramTypes, byte[] resultTypes, boolean i return typeIdx; } - void registerFunctionTypeParameterType(int funcTypeIdx, int paramIdx, byte type) { + void registerFunctionTypeParameterType(int funcTypeIdx, int paramIdx, int type) { checkNotParsed(); int idx = 2 + typeOffsets[funcTypeIdx] + paramIdx; typeData[idx] = type; } - void registerFunctionTypeResultType(int funcTypeIdx, int resultIdx, byte type) { + void registerFunctionTypeResultType(int funcTypeIdx, int resultIdx, int type) { checkNotParsed(); int idx = 2 + typeOffsets[funcTypeIdx] + typeData[typeOffsets[funcTypeIdx]] + resultIdx; typeData[idx] = type; } - public int equivalenceClass(int typeIndex) { - return typeEquivalenceClasses[typeIndex]; - } - - void setEquivalenceClass(int index, int eqClass) { - checkNotParsed(); - if (typeEquivalenceClasses[index] != NO_EQUIVALENCE_CLASS) { - throw WasmException.create(Failure.UNSPECIFIED_INVALID, "Type at index " + index + " already has an equivalence class."); - } - typeEquivalenceClasses[index] = eqClass; - } - private void ensureFunctionsCapacity(int index) { if (functions.length <= index) { int newLength = Math.max(Integer.highestOneBit(index) << 1, 2 * functions.length); @@ -655,29 +764,29 @@ public WasmFunction startFunction() { protected abstract WasmModule module(); - public byte functionTypeParamTypeAt(int typeIndex, int i) { + public int functionTypeParamTypeAt(int typeIndex, int paramIndex) { int typeOffset = typeOffsets[typeIndex]; - return (byte) typeData[typeOffset + 2 + i]; + return typeData[typeOffset + 2 + paramIndex]; } - public byte functionTypeResultTypeAt(int typeIndex, int resultIndex) { + public int functionTypeResultTypeAt(int typeIndex, int resultIndex) { int typeOffset = typeOffsets[typeIndex]; int paramCount = typeData[typeOffset]; - return (byte) typeData[typeOffset + 2 + paramCount + resultIndex]; + return typeData[typeOffset + 2 + paramCount + resultIndex]; } - private byte[] functionTypeParamTypesAsArray(int typeIndex) { + private int[] functionTypeParamTypesAsArray(int typeIndex) { int paramCount = functionTypeParamCount(typeIndex); - byte[] paramTypes = new byte[paramCount]; + int[] paramTypes = new int[paramCount]; for (int i = 0; i < paramCount; ++i) { paramTypes[i] = functionTypeParamTypeAt(typeIndex, i); } return paramTypes; } - private byte[] functionTypeResultTypesAsArray(int typeIndex) { + private int[] functionTypeResultTypesAsArray(int typeIndex) { int resultTypeCount = functionTypeResultCount(typeIndex); - byte[] resultTypes = new byte[resultTypeCount]; + int[] resultTypes = new int[resultTypeCount]; for (int i = 0; i < resultTypeCount; i++) { resultTypes[i] = functionTypeResultTypeAt(typeIndex, i); } @@ -689,7 +798,45 @@ int typeCount() { } public FunctionType typeAt(int index) { - return new FunctionType(functionTypeParamTypesAsArray(index), functionTypeResultTypesAsArray(index)); + return new FunctionType(functionTypeParamTypesAsArray(index), functionTypeResultTypesAsArray(index), this); + } + + public ClosedFunctionType closedFunctionTypeAt(int typeIndex) { + ClosedValueType[] paramTypes = new ClosedValueType[functionTypeParamCount(typeIndex)]; + for (int i = 0; i < paramTypes.length; i++) { + paramTypes[i] = closedTypeAt(functionTypeParamTypeAt(typeIndex, i)); + } + ClosedValueType[] resultTypes = new ClosedValueType[functionTypeResultCount(typeIndex)]; + for (int i = 0; i < resultTypes.length; i++) { + resultTypes[i] = closedTypeAt(functionTypeResultTypeAt(typeIndex, i)); + } + return new ClosedFunctionType(paramTypes, resultTypes); + } + + public ClosedValueType closedTypeAt(int type) { + return switch (type) { + case WasmType.I32_TYPE -> NumberType.I32; + case WasmType.I64_TYPE -> NumberType.I64; + case WasmType.F32_TYPE -> NumberType.F32; + case WasmType.F64_TYPE -> NumberType.F64; + case WasmType.V128_TYPE -> VectorType.V128; + default -> { + assert WasmType.isReferenceType(type); + boolean nullable = WasmType.isNullable(type); + yield switch (WasmType.getAbstractHeapType(type)) { + case WasmType.FUNC_HEAPTYPE -> nullable ? ClosedReferenceType.FUNCREF : ClosedReferenceType.NONNULL_FUNCREF; + case WasmType.EXTERN_HEAPTYPE -> nullable ? ClosedReferenceType.EXTERNREF : ClosedReferenceType.NONNULL_EXTERNREF; + default -> { + assert WasmType.isConcreteReferenceType(type); + yield new ClosedReferenceType(nullable, closedFunctionTypeAt(WasmType.getTypeIndex(type))); + } + }; + } + }; + } + + public boolean matches(int superType, int subType) { + return closedTypeAt(superType).matches(closedTypeAt(subType)); } public void importSymbol(ImportDescriptor descriptor) { @@ -758,11 +905,14 @@ public WasmFunction importedFunction(ImportDescriptor descriptor) { private void ensureGlobalsCapacity(int index) { while (index >= globalInitializers.length) { - final byte[] nGlobalTypes = new byte[globalTypes.length * 2]; + final int[] nGlobalTypes = new int[globalTypes.length * 2]; + final byte[] nGlobalFlags = new byte[globalFlags.length * 2]; final Object[] nGlobalInitializers = new Object[globalInitializers.length * 2]; System.arraycopy(globalTypes, 0, nGlobalTypes, 0, globalTypes.length); + System.arraycopy(globalFlags, 0, nGlobalFlags, 0, globalFlags.length); System.arraycopy(globalInitializers, 0, nGlobalInitializers, 0, globalInitializers.length); globalTypes = nGlobalTypes; + globalFlags = nGlobalFlags; globalInitializers = nGlobalInitializers; } } @@ -779,8 +929,7 @@ private void ensureGlobalInitializersBytecodeCapacity(int index) { * Allocates a global index in the symbol table, for a global variable that was already * allocated. */ - void allocateGlobal(int index, byte valueType, byte mutability, boolean initialized, boolean imported, byte[] initBytecode, Object initialValue) { - assert (valueType & 0xff) == valueType; + void allocateGlobal(int index, int valueType, byte mutability, boolean initialized, boolean imported, byte[] initBytecode, Object initialValue) { checkNotParsed(); ensureGlobalsCapacity(index); numGlobals = maxUnsigned(index + 1, numGlobals); @@ -808,8 +957,8 @@ void allocateGlobal(int index, byte valueType, byte mutability, boolean initiali globalInitializersBytecode[initBytecodeIndex] = initBytecode; globalInitializers[index] = initBytecodeIndex; } - globalTypes[2 * index] = valueType; - globalTypes[2 * index + 1] = flags; + globalTypes[index] = valueType; + globalFlags[index] = flags; } /** @@ -817,7 +966,7 @@ void allocateGlobal(int index, byte valueType, byte mutability, boolean initiali * default, but may be exported using {@link #exportGlobal}. Imported globals are declared using * {@link #importGlobal} instead. This method may only be called during parsing, before linking. */ - void declareGlobal(int index, byte valueType, byte mutability, boolean initialized, byte[] initBytecode, Object initialValue) { + void declareGlobal(int index, int valueType, byte mutability, boolean initialized, byte[] initBytecode, Object initialValue) { assert initialized == (initBytecode == null) : index; allocateGlobal(index, valueType, mutability, initialized, false, initBytecode, initialValue); module().addLinkAction((context, store, instance, imports) -> { @@ -828,7 +977,7 @@ void declareGlobal(int index, byte valueType, byte mutability, boolean initializ /** * Declares an imported global. May be re-exported. */ - void importGlobal(String moduleName, String globalName, int index, byte valueType, byte mutability) { + void importGlobal(String moduleName, String globalName, int index, int valueType, byte mutability) { final ImportDescriptor descriptor = new ImportDescriptor(moduleName, globalName, ImportIdentifier.GLOBAL, index, numImportedSymbols()); importedGlobals.put(index, descriptor); importSymbol(descriptor); @@ -879,12 +1028,12 @@ public boolean isGlobalMutable(int index) { return globalMutability(index) == GlobalModifier.MUTABLE; } - public byte globalValueType(int index) { - return globalTypes[2 * index]; + public int globalValueType(int index) { + return globalTypes[index]; } private byte globalFlags(int index) { - return globalTypes[2 * index + 1]; + return globalFlags[index]; } public boolean globalInitialized(int index) { @@ -929,14 +1078,14 @@ void exportGlobal(String name, int index) { if (!globalExternal(index)) { numExternalGlobals++; } - globalTypes[2 * index + 1] |= GLOBAL_EXPORTED_BIT; + globalFlags[index] |= GLOBAL_EXPORTED_BIT; exportedGlobals.put(name, index); module().addLinkAction((context, store, instance, imports) -> { store.linker().resolveGlobalExport(instance.module(), name, index); }); } - public void declareExportedGlobalWithValue(String name, int index, byte valueType, byte mutability, Object value) { + public void declareExportedGlobalWithValue(String name, int index, int valueType, byte mutability, Object value) { checkNotParsed(); declareGlobal(index, valueType, mutability, true, null, value); exportGlobal(name, index); @@ -950,7 +1099,7 @@ private void ensureTableCapacity(int index) { } } - public void allocateTable(int index, int declaredMinSize, int declaredMaxSize, byte elemType, boolean referenceTypes) { + public void allocateTable(int index, int declaredMinSize, int declaredMaxSize, int elemType, boolean referenceTypes) { checkNotParsed(); addTable(index, declaredMinSize, declaredMaxSize, elemType, referenceTypes); module().addLinkAction((context, store, instance, imports) -> { @@ -968,7 +1117,7 @@ public void allocateTable(int index, int declaredMinSize, int declaredMaxSize, b }); } - void importTable(String moduleName, String tableName, int index, int initSize, int maxSize, byte elemType, boolean referenceTypes) { + void importTable(String moduleName, String tableName, int index, int initSize, int maxSize, int elemType, boolean referenceTypes) { checkNotParsed(); addTable(index, initSize, maxSize, elemType, referenceTypes); final ImportDescriptor importedTable = new ImportDescriptor(moduleName, tableName, ImportIdentifier.TABLE, index, numImportedSymbols()); @@ -980,7 +1129,7 @@ void importTable(String moduleName, String tableName, int index, int initSize, i }); } - void addTable(int index, int minSize, int maxSize, byte elemType, boolean referenceTypes) { + void addTable(int index, int minSize, int maxSize, int elemType, boolean referenceTypes) { if (!referenceTypes) { assertTrue(importedTables.isEmpty(), "A table has already been imported in the module.", Failure.MULTIPLE_TABLES); assertTrue(tableCount == 0, "A table has already been declared in the module.", Failure.MULTIPLE_TABLES); @@ -1040,7 +1189,7 @@ public int tableMaximumSize(int index) { return table.maximumSize; } - public byte tableElementType(int index) { + public int tableElementType(int index) { final TableInfo table = tables[index]; assert table != null; return table.elemType; @@ -1314,8 +1463,8 @@ public void checkElemIndex(int elemIndex) { assertUnsignedIntLess(elemIndex, elemSegmentCount, Failure.UNKNOWN_ELEM_SEGMENT); } - public void checkElemType(int elemIndex, byte expectedType) { - assertByteEqual(expectedType, (byte) elemInstances[elemIndex], Failure.TYPE_MISMATCH); + public void checkElemType(int elemIndex, int expectedType) { + assertIntEqual(expectedType, (int) elemInstances[elemIndex], Failure.TYPE_MISMATCH); } private void ensureElemInstanceCapacity(int index) { @@ -1328,9 +1477,9 @@ private void ensureElemInstanceCapacity(int index) { } } - void setElemInstance(int index, int offset, byte elemType) { + void setElemInstance(int index, int offset, int elemType) { ensureElemInstanceCapacity(index); - elemInstances[index] = (long) offset << 32 | (elemType & 0xFF); + elemInstances[index] = (long) offset << 32 | (elemType & 0xFFFF_FFFFL); elemSegmentCount++; } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmCodeEntry.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmCodeEntry.java index 6dd87c4a6a35..645352e95e37 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmCodeEntry.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmCodeEntry.java @@ -47,15 +47,15 @@ public final class WasmCodeEntry { private final WasmFunction function; @CompilationFinal(dimensions = 1) private final byte[] bytecode; - @CompilationFinal(dimensions = 1) private final byte[] localTypes; - @CompilationFinal(dimensions = 1) private final byte[] resultTypes; + @CompilationFinal(dimensions = 1) private final int[] localTypes; + @CompilationFinal(dimensions = 1) private final int[] resultTypes; private final BranchProfile errorBranch = BranchProfile.create(); private final BranchProfile exceptionBranch = BranchProfile.create(); private final int numLocals; private final int resultCount; private final boolean usesMemoryZero; - public WasmCodeEntry(WasmFunction function, byte[] bytecode, byte[] localTypes, byte[] resultTypes, boolean usesMemoryZero) { + public WasmCodeEntry(WasmFunction function, byte[] bytecode, int[] localTypes, int[] resultTypes, boolean usesMemoryZero) { this.function = function; this.bytecode = bytecode; this.localTypes = localTypes; @@ -73,7 +73,7 @@ public byte[] bytecode() { return bytecode; } - public byte localType(int index) { + public int localType(int index) { return localTypes[index]; } @@ -89,7 +89,7 @@ public int resultCount() { return resultCount; } - public byte resultType(int index) { + public int resultType(int index) { return resultTypes[index]; } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmFunction.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmFunction.java index 4a96698b673e..20bca810317d 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmFunction.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmFunction.java @@ -49,7 +49,6 @@ public final class WasmFunction { private final int index; private final ImportDescriptor importDescriptor; private final int typeIndex; - @CompilationFinal private int typeEquivalenceClass; @CompilationFinal private String debugName; @CompilationFinal private CallTarget callTarget; /** Interop call adapter for argument and return value validation and conversion. */ @@ -63,7 +62,6 @@ public WasmFunction(SymbolTable symbolTable, int index, int typeIndex, ImportDes this.index = index; this.importDescriptor = importDescriptor; this.typeIndex = typeIndex; - this.typeEquivalenceClass = -1; } public String moduleName() { @@ -74,7 +72,7 @@ public int paramCount() { return symbolTable.functionTypeParamCount(typeIndex); } - public byte paramTypeAt(int argumentIndex) { + public int paramTypeAt(int argumentIndex) { return symbolTable.functionTypeParamTypeAt(typeIndex, argumentIndex); } @@ -82,14 +80,10 @@ public int resultCount() { return symbolTable.functionTypeResultCount(typeIndex); } - public byte resultTypeAt(int returnIndex) { + public int resultTypeAt(int returnIndex) { return symbolTable.functionTypeResultTypeAt(typeIndex, returnIndex); } - void setTypeEquivalenceClass(int typeEquivalenceClass) { - this.typeEquivalenceClass = typeEquivalenceClass; - } - @Override public String toString() { return name(); @@ -146,8 +140,8 @@ public SymbolTable.FunctionType type() { return symbolTable.typeAt(typeIndex()); } - public int typeEquivalenceClass() { - return typeEquivalenceClass; + public SymbolTable.ClosedFunctionType closedType() { + return symbolTable.closedFunctionTypeAt(typeIndex()); } public int index() { diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmInstantiator.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmInstantiator.java index 107f9085a63c..e359b04a3846 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmInstantiator.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmInstantiator.java @@ -111,7 +111,7 @@ static List recreateLinkActions(WasmModule module) { final EconomicMap importedGlobals = module.importedGlobals(); for (int i = 0; i < module.numGlobals(); i++) { final int globalIndex = i; - final byte globalValueType = module.globalValueType(globalIndex); + final int globalValueType = module.globalValueType(globalIndex); final byte globalMutability = module.globalMutability(globalIndex); if (importedGlobals.containsKey(globalIndex)) { final ImportDescriptor globalDescriptor = importedGlobals.get(globalIndex); @@ -139,7 +139,7 @@ static List recreateLinkActions(WasmModule module) { final int tableIndex = i; final int tableMinSize = module.tableInitialSize(tableIndex); final int tableMaxSize = module.tableMaximumSize(tableIndex); - final byte tableElemType = module.tableElementType(tableIndex); + final int tableElemType = module.tableElementType(tableIndex); final ImportDescriptor tableDescriptor = module.importedTable(tableIndex); if (tableDescriptor != null) { linkActions.add((context, store, instance, imports) -> { @@ -331,11 +331,22 @@ static List recreateLinkActions(WasmModule module) { final int elemIndex = i; final int elemOffset = module.elemInstanceOffset(elemIndex); final int encoding = bytecode[elemOffset]; - final int typeAndMode = bytecode[elemOffset + 1]; + final int typeLengthAndMode = bytecode[elemOffset + 1]; int effectiveOffset = elemOffset + 2; - final int elemMode = typeAndMode & BytecodeBitEncoding.ELEM_SEG_MODE_VALUE; + final int elemMode = typeLengthAndMode & BytecodeBitEncoding.ELEM_SEG_MODE_VALUE; + switch (typeLengthAndMode & BytecodeBitEncoding.ELEM_SEG_TYPE_MASK) { + case BytecodeBitEncoding.ELEM_SEG_TYPE_I8: + effectiveOffset++; + break; + case BytecodeBitEncoding.ELEM_SEG_TYPE_I16: + effectiveOffset += 2; + break; + case BytecodeBitEncoding.ELEM_SEG_TYPE_I32: + effectiveOffset += 4; + break; + } final int elemCount; switch (encoding & BytecodeBitEncoding.ELEM_SEG_COUNT_MASK) { case BytecodeBitEncoding.ELEM_SEG_COUNT_U8: @@ -500,7 +511,7 @@ private static void resolveInstantiatedCodeEntries(WasmStore store, WasmInstance } } - private static FrameDescriptor createFrameDescriptor(byte[] localTypes, int maxStackSize) { + private static FrameDescriptor createFrameDescriptor(int[] localTypes, int maxStackSize) { FrameDescriptor.Builder builder = FrameDescriptor.newBuilder(localTypes.length); builder.addSlots(localTypes.length + maxStackSize, FrameSlotKind.Static); return builder.build(); diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmLanguage.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmLanguage.java index 6e009ea98ce9..c5c6f90cd6aa 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmLanguage.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmLanguage.java @@ -95,26 +95,8 @@ public final class WasmLanguage extends TruffleLanguage { private final Map builtinModules = new ConcurrentHashMap<>(); - private final Map equivalenceClasses = new ConcurrentHashMap<>(); - private int nextEquivalenceClass = SymbolTable.FIRST_EQUIVALENCE_CLASS; private final Map interopCallAdapters = new ConcurrentHashMap<>(); - public int equivalenceClassFor(SymbolTable.FunctionType type) { - CompilerAsserts.neverPartOfCompilation(); - Integer equivalenceClass = equivalenceClasses.get(type); - if (equivalenceClass == null) { - synchronized (this) { - equivalenceClass = equivalenceClasses.get(type); - if (equivalenceClass == null) { - equivalenceClass = nextEquivalenceClass++; - Integer prev = equivalenceClasses.put(type, equivalenceClass); - assert prev == null; - } - } - } - return equivalenceClass; - } - /** * Gets or creates the interop call adapter for a function type. Always returns the same call * target for any particular type. diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmTable.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmTable.java index dc5d0813bd05..2473e9f94bea 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmTable.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmTable.java @@ -65,7 +65,7 @@ public final class WasmTable extends EmbedderDataHolder implements TruffleObject /** * @see #elemType() */ - private final byte elemType; + private final int elemType; /** * @see #minSize() @@ -86,7 +86,7 @@ public final class WasmTable extends EmbedderDataHolder implements TruffleObject private Object[] elements; @TruffleBoundary - private WasmTable(int declaredMinSize, int declaredMaxSize, int initialSize, int maxAllowedSize, byte elemType, Object initialValue) { + private WasmTable(int declaredMinSize, int declaredMaxSize, int initialSize, int maxAllowedSize, int elemType, Object initialValue) { assert compareUnsigned(declaredMinSize, initialSize) <= 0; assert compareUnsigned(initialSize, maxAllowedSize) <= 0; assert compareUnsigned(maxAllowedSize, declaredMaxSize) <= 0; @@ -103,11 +103,11 @@ private WasmTable(int declaredMinSize, int declaredMaxSize, int initialSize, int this.elemType = elemType; } - public WasmTable(int declaredMinSize, int declaredMaxSize, int maxAllowedSize, byte elemType) { + public WasmTable(int declaredMinSize, int declaredMaxSize, int maxAllowedSize, int elemType) { this(declaredMinSize, declaredMaxSize, declaredMinSize, maxAllowedSize, elemType, WasmConstant.NULL); } - public WasmTable(int declaredMinSize, int declaredMaxSize, int maxAllowedSize, byte elemType, Object initialValue) { + public WasmTable(int declaredMinSize, int declaredMaxSize, int maxAllowedSize, int elemType, Object initialValue) { this(declaredMinSize, declaredMaxSize, declaredMinSize, maxAllowedSize, elemType, initialValue); } @@ -156,9 +156,10 @@ public int declaredMaxSize() { *

* This table can only be imported with an equivalent elem type. * - * @return Either {@link WasmType#FUNCREF_TYPE} or {@link WasmType#EXTERNREF_TYPE}. + * @return Either {@link WasmType#FUNCREF_TYPE}, {@link WasmType#EXTERNREF_TYPE} or some + * concrete reference type. */ - public byte elemType() { + public int elemType() { return elemType; } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmType.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmType.java index 9edc05402c12..9066d1e7d358 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmType.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmType.java @@ -40,9 +40,6 @@ */ package org.graalvm.wasm; -import org.graalvm.wasm.exception.Failure; -import org.graalvm.wasm.exception.WasmException; - import com.oracle.truffle.api.CompilerAsserts; import com.oracle.truffle.api.CompilerDirectives.CompilationFinal; import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; @@ -55,39 +52,63 @@ @ExportLibrary(InteropLibrary.class) @SuppressWarnings({"unused", "static-method"}) public class WasmType implements TruffleObject { - public static final byte VOID_TYPE = 0x40; - @CompilationFinal(dimensions = 1) public static final byte[] VOID_TYPE_ARRAY = {}; - public static final byte NULL_TYPE = 0x00; - - public static final byte UNKNOWN_TYPE = -1; + private static final int TYPE_INDEX_BITS = 30; + private static final int TYPE_VALUE_MASK = (1 << TYPE_INDEX_BITS) - 1; + private static final int TYPE_NULLABLE_MASK = 1 << TYPE_INDEX_BITS; + private static final int TYPE_PREDEFINED_MASK = 1 << (TYPE_INDEX_BITS + 1); + public static final int MAX_TYPE_INDEX = TYPE_VALUE_MASK; /** * Number Types. */ - public static final byte I32_TYPE = 0x7F; - @CompilationFinal(dimensions = 1) public static final byte[] I32_TYPE_ARRAY = {I32_TYPE}; + public static final int I32_TYPE = -0x01; + @CompilationFinal(dimensions = 1) public static final int[] I32_TYPE_ARRAY = {I32_TYPE}; - public static final byte I64_TYPE = 0x7E; - @CompilationFinal(dimensions = 1) public static final byte[] I64_TYPE_ARRAY = {I64_TYPE}; + public static final int I64_TYPE = -0x02; + @CompilationFinal(dimensions = 1) public static final int[] I64_TYPE_ARRAY = {I64_TYPE}; - public static final byte F32_TYPE = 0x7D; - @CompilationFinal(dimensions = 1) public static final byte[] F32_TYPE_ARRAY = {F32_TYPE}; + public static final int F32_TYPE = -0x03; + @CompilationFinal(dimensions = 1) public static final int[] F32_TYPE_ARRAY = {F32_TYPE}; - public static final byte F64_TYPE = 0x7C; - @CompilationFinal(dimensions = 1) public static final byte[] F64_TYPE_ARRAY = {F64_TYPE}; + public static final int F64_TYPE = -0x04; + @CompilationFinal(dimensions = 1) public static final int[] F64_TYPE_ARRAY = {F64_TYPE}; - public static final byte V128_TYPE = 0x7B; - @CompilationFinal(dimensions = 1) public static final byte[] V128_TYPE_ARRAY = {V128_TYPE}; + /** + * Vector Type. + */ + public static final int V128_TYPE = -0x05; + @CompilationFinal(dimensions = 1) public static final int[] V128_TYPE_ARRAY = {V128_TYPE}; /** * Reference Types. */ - public static final byte FUNCREF_TYPE = 0x70; - @CompilationFinal(dimensions = 1) public static final byte[] FUNCREF_TYPE_ARRAY = {FUNCREF_TYPE}; - public static final byte EXTERNREF_TYPE = 0x6F; - @CompilationFinal(dimensions = 1) public static final byte[] EXTERNREF_TYPE_ARRAY = {EXTERNREF_TYPE}; - public static final byte EXNREF_TYPE = 0x69; - @CompilationFinal(dimensions = 1) public static final byte[] EXNREF_TYPE_ARRAY = {EXNREF_TYPE}; + public static final int FUNC_HEAPTYPE = -0x10; + public static final int EXTERN_HEAPTYPE = -0x11; + public static final int EXN_HEAPTYPE = -0x17; + + public static final int FUNCREF_TYPE = FUNC_HEAPTYPE; + @CompilationFinal(dimensions = 1) public static final int[] FUNCREF_TYPE_ARRAY = {FUNCREF_TYPE}; + + public static final int EXTERNREF_TYPE = EXTERN_HEAPTYPE; + @CompilationFinal(dimensions = 1) public static final int[] EXTERNREF_TYPE_ARRAY = {EXTERNREF_TYPE}; + + public static final int EXNREF_TYPE = EXN_HEAPTYPE; + @CompilationFinal(dimensions = 1) public static final int[] EXNREF_TYPE_ARRAY = {EXNREF_TYPE}; + + /** + * Implementation-specific Types. + */ + public static final int NULL_TYPE = -0x7e; + public static final int UNKNOWN_TYPE = -0x7f; + + /** + * Bytes used in the binary encoding of types. + */ + public static final byte REF_TYPE_HEADER = -0x1c; + public static final byte REF_NULL_TYPE_HEADER = -0x1d; + // -0x40 is what the void block type byte 0x40 looks like when read as a signed LEB128 value. + public static final byte VOID_BLOCK_TYPE = -0x40; + @CompilationFinal(dimensions = 1) public static final int[] VOID_TYPE_ARRAY = {}; public static final WasmType VOID = new WasmType("void"); public static final WasmType NULL = new WasmType("null"); @@ -103,45 +124,73 @@ public class WasmType implements TruffleObject { public static String toString(int valueType) { CompilerAsserts.neverPartOfCompilation(); - switch (valueType) { - case I32_TYPE: - return "i32"; - case I64_TYPE: - return "i64"; - case F32_TYPE: - return "f32"; - case F64_TYPE: - return "f64"; - case V128_TYPE: - return "v128"; - case VOID_TYPE: - return "void"; - case FUNCREF_TYPE: - return "funcref"; - case EXTERNREF_TYPE: - return "externref"; - case EXNREF_TYPE: - return "exnref"; - default: - throw WasmException.create(Failure.UNSPECIFIED_INTERNAL, null, "Unknown value type: 0x" + Integer.toHexString(valueType)); - } + return switch (valueType) { + case I32_TYPE -> "i32"; + case I64_TYPE -> "i64"; + case F32_TYPE -> "f32"; + case F64_TYPE -> "f64"; + case V128_TYPE -> "v128"; + default -> { + assert WasmType.isReferenceType(valueType); + boolean nullable = WasmType.isNullable(valueType); + yield switch (WasmType.getAbstractHeapType(valueType)) { + case FUNC_HEAPTYPE -> nullable ? "funcref" : "(ref func)"; + case EXTERN_HEAPTYPE -> nullable ? "externref" : "(ref extern)"; + case EXN_HEAPTYPE -> nullable ? "exnref" : "(ref exn)"; + default -> { + assert WasmType.isConcreteReferenceType(valueType); + StringBuilder sb = new StringBuilder(16); + sb.append("(ref "); + if (nullable) { + sb.append("null "); + } + sb.append(getTypeIndex(valueType)); + sb.append(")"); + yield sb.toString(); + } + }; + } + }; } - public static boolean isNumberType(byte type) { + public static boolean isNumberType(int type) { return type == I32_TYPE || type == I64_TYPE || type == F32_TYPE || type == F64_TYPE || type == UNKNOWN_TYPE; } - public static boolean isVectorType(byte type) { + public static boolean isVectorType(int type) { return type == V128_TYPE || type == UNKNOWN_TYPE; } - public static boolean isReferenceType(byte type) { - return type == FUNCREF_TYPE || type == EXTERNREF_TYPE || type == EXNREF_TYPE || type == UNKNOWN_TYPE; + public static boolean isReferenceType(int type) { + return isConcreteReferenceType(type) || withNullable(true, type) == FUNC_HEAPTYPE || withNullable(true, type) == EXTERN_HEAPTYPE || withNullable(true, type) == EXN_HEAPTYPE || type == UNKNOWN_TYPE; + } + + public static boolean isConcreteReferenceType(int type) { + return type >= 0; + } + + public static int getTypeIndex(int type) { + assert isConcreteReferenceType(type); + return type & TYPE_VALUE_MASK; + } + + public static int getAbstractHeapType(int type) { + assert isReferenceType(type); + return withNullable(true, type); + } + + public static boolean isNullable(int type) { + assert isReferenceType(type); + return (type & TYPE_NULLABLE_MASK) != 0; + } + + public static int withNullable(boolean nullable, int type) { + return nullable ? type | TYPE_NULLABLE_MASK : type & ~TYPE_NULLABLE_MASK; } - public static int getCommonValueType(byte[] types) { + public static int getCommonValueType(int[] types) { int type = 0; - for (byte resultType : types) { + for (int resultType : types) { type |= WasmType.isNumberType(resultType) ? NUM_COMMON_TYPE : 0; type |= WasmType.isVectorType(resultType) || WasmType.isReferenceType(resultType) ? OBJ_COMMON_TYPE : 0; } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ExecuteHostFunctionNode.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ExecuteHostFunctionNode.java index 3e9440eac6af..f4b79c91106f 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ExecuteHostFunctionNode.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ExecuteHostFunctionNode.java @@ -108,7 +108,7 @@ public Object execute(VirtualFrame frame) { if (resultCount == 0) { return WasmConstant.VOID; } else if (resultCount == 1) { - byte resultType = module.symbolTable().functionTypeResultTypeAt(functionTypeIndex, 0); + int resultType = module.symbolTable().functionTypeResultTypeAt(functionTypeIndex, 0); return convertResult(result, resultType); } else { pushMultiValueResult(result, resultCount); @@ -125,16 +125,27 @@ public Object execute(VirtualFrame frame) { * of the correct boxed type because they're already converted on the JS side, so we only need * to unbox and can forego InteropLibrary. */ - private Object convertResult(Object result, byte resultType) throws UnsupportedMessageException { + private Object convertResult(Object result, int resultType) throws UnsupportedMessageException { CompilerAsserts.partialEvaluationConstant(resultType); return switch (resultType) { case WasmType.I32_TYPE -> asInt(result); case WasmType.I64_TYPE -> asLong(result); case WasmType.F32_TYPE -> asFloat(result); case WasmType.F64_TYPE -> asDouble(result); - case WasmType.V128_TYPE, WasmType.FUNCREF_TYPE, WasmType.EXTERNREF_TYPE, WasmType.EXNREF_TYPE -> result; + case WasmType.V128_TYPE -> { + if (!(result instanceof Vector128)) { + errorBranch.enter(); + throw WasmException.create(Failure.TYPE_MISMATCH); + } + yield result; + } default -> { - throw WasmException.format(Failure.UNSPECIFIED_TRAP, this, "Unknown result type: %d", resultType); + assert WasmType.isReferenceType(resultType); + if (!WasmType.isNullable(resultType) && result == WasmConstant.NULL) { + errorBranch.enter(); + throw WasmException.create(Failure.TYPE_MISMATCH); + } + yield result; } }; } @@ -157,7 +168,7 @@ private void pushMultiValueResult(Object result, int resultCount) { final long[] primitiveMultiValueStack = multiValueStack.primitiveStack(); final Object[] objectMultiValueStack = multiValueStack.objectStack(); for (int i = 0; i < resultCount; i++) { - byte resultType = module.symbolTable().functionTypeResultTypeAt(functionTypeIndex, i); + int resultType = module.symbolTable().functionTypeResultTypeAt(functionTypeIndex, i); CompilerAsserts.partialEvaluationConstant(resultType); Object value = arrayInterop.readArrayElement(result, i); switch (resultType) { @@ -172,10 +183,12 @@ private void pushMultiValueResult(Object result, int resultCount) { } objectMultiValueStack[i] = value; } - case WasmType.FUNCREF_TYPE, WasmType.EXTERNREF_TYPE, WasmType.EXNREF_TYPE -> objectMultiValueStack[i] = value; default -> { - errorBranch.enter(); - throw WasmException.format(Failure.UNSPECIFIED_TRAP, this, "Unknown result type: %d", resultType); + assert WasmType.isReferenceType(resultType); + if (!WasmType.isNullable(resultType) && value == WasmConstant.NULL) { + throw WasmException.create(Failure.INVALID_TYPE_IN_MULTI_VALUE); + } + objectMultiValueStack[i] = value; } } } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/FuncType.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/FuncType.java index 560eedc8eb81..568a234ad6b2 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/FuncType.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/FuncType.java @@ -85,17 +85,17 @@ private static ValueType[] parseTypeString(String typesString, int start, int en } public static FuncType fromFunctionType(SymbolTable.FunctionType functionType) { - final byte[] paramTypes = functionType.paramTypes(); - final byte[] resultTypes = functionType.resultTypes(); + final int[] paramTypes = functionType.paramTypes(); + final int[] resultTypes = functionType.resultTypes(); final ValueType[] params = new ValueType[paramTypes.length]; final ValueType[] results = new ValueType[resultTypes.length]; for (int i = 0; i < paramTypes.length; i++) { - params[i] = ValueType.fromByteValue(paramTypes[i]); + params[i] = ValueType.fromValue(paramTypes[i]); } for (int i = 0; i < resultTypes.length; i++) { - results[i] = ValueType.fromByteValue(resultTypes[i]); + results[i] = ValueType.fromValue(resultTypes[i]); } return new FuncType(params, results); } @@ -117,16 +117,16 @@ public int resultCount() { } public SymbolTable.FunctionType toFunctionType() { - final byte[] paramTypes = new byte[params.length]; - final byte[] resultTypes = new byte[results.length]; + final int[] paramTypes = new int[params.length]; + final int[] resultTypes = new int[results.length]; for (int i = 0; i < paramTypes.length; i++) { - paramTypes[i] = params[i].byteValue(); + paramTypes[i] = params[i].value(); } for (int i = 0; i < resultTypes.length; i++) { - resultTypes[i] = results[i].byteValue(); + resultTypes[i] = results[i].value(); } - return SymbolTable.FunctionType.create(paramTypes, resultTypes); + return SymbolTable.FunctionType.createClosed(paramTypes, resultTypes); } public String toString(StringBuilder b) { diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/InteropCallAdapterNode.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/InteropCallAdapterNode.java index c74567a05517..15cf1c5e8f56 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/InteropCallAdapterNode.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/InteropCallAdapterNode.java @@ -47,8 +47,6 @@ import org.graalvm.wasm.WasmFunctionInstance; import org.graalvm.wasm.WasmLanguage; import org.graalvm.wasm.WasmType; -import org.graalvm.wasm.exception.Failure; -import org.graalvm.wasm.exception.WasmException; import org.graalvm.wasm.exception.WasmRuntimeException; import org.graalvm.wasm.nodes.WasmIndirectCallNode; @@ -111,31 +109,31 @@ public Object execute(VirtualFrame frame) { } private Object[] validateArguments(Object[] arguments, int offset) throws ArityException, UnsupportedTypeException { - final byte[] paramTypes = functionType.paramTypes(); + final int[] paramTypes = functionType.paramTypes(); final int paramCount = paramTypes.length; CompilerAsserts.partialEvaluationConstant(paramCount); if (arguments.length - offset != paramCount) { throw ArityException.create(paramCount, paramCount, arguments.length - offset); } if (CompilerDirectives.inCompiledCode() && paramCount <= MAX_UNROLL) { - validateArgumentsUnroll(arguments, offset, paramTypes, paramCount); + validateArgumentsUnroll(arguments, offset, paramTypes, paramCount, functionType.symbolTable()); } else { for (int i = 0; i < paramCount; i++) { - validateArgument(arguments, offset, paramTypes, i); + validateArgument(arguments, offset, paramTypes, i, functionType.symbolTable()); } } return arguments; } @ExplodeLoop - private static void validateArgumentsUnroll(Object[] arguments, int offset, byte[] paramTypes, int paramCount) throws UnsupportedTypeException { + private static void validateArgumentsUnroll(Object[] arguments, int offset, int[] paramTypes, int paramCount, SymbolTable symbolTable) throws UnsupportedTypeException { for (int i = 0; i < paramCount; i++) { - validateArgument(arguments, offset, paramTypes, i); + validateArgument(arguments, offset, paramTypes, i, symbolTable); } } - private static void validateArgument(Object[] arguments, int offset, byte[] paramTypes, int i) throws UnsupportedTypeException { - byte paramType = paramTypes[i]; + private static void validateArgument(Object[] arguments, int offset, int[] paramTypes, int i, SymbolTable symbolTable) throws UnsupportedTypeException { + int paramType = paramTypes[i]; Object value = arguments[i + offset]; switch (paramType) { case WasmType.I32_TYPE -> { @@ -163,20 +161,34 @@ private static void validateArgument(Object[] arguments, int offset, byte[] para return; } } - case WasmType.FUNCREF_TYPE -> { - if (value instanceof WasmFunctionInstance || value == WasmConstant.NULL) { - return; - } - } - case WasmType.EXTERNREF_TYPE -> { - return; - } - case WasmType.EXNREF_TYPE -> { - if (value instanceof WasmRuntimeException || value == WasmConstant.NULL) { - return; + default -> { + assert WasmType.isReferenceType(paramType); + boolean nullable = WasmType.isNullable(paramType); + switch (WasmType.getAbstractHeapType(paramType)) { + case WasmType.FUNCREF_TYPE -> { + if (value instanceof WasmFunctionInstance || nullable && value == WasmConstant.NULL) { + return; + } + } + case WasmType.EXTERNREF_TYPE -> { + if (nullable || value != WasmConstant.NULL) { + return; + } + } + case WasmType.EXNREF_TYPE -> { + if (value instanceof WasmRuntimeException || nullable && value == WasmConstant.NULL) { + return; + } + } + default -> { + assert WasmType.isConcreteReferenceType(paramType); + SymbolTable.ClosedFunctionType functionType = symbolTable.closedFunctionTypeAt(WasmType.getTypeIndex(paramType)); + if (value instanceof WasmFunctionInstance instance && functionType.matches(instance.function().closedType()) || nullable && value == WasmConstant.NULL) { + return; + } + } } } - default -> throw WasmException.create(Failure.UNKNOWN_TYPE); } throw UnsupportedTypeException.create(arguments); } @@ -185,7 +197,7 @@ private Object multiValueStackAsArray(WasmLanguage language) { final var multiValueStack = language.multiValueStack(); final long[] primitiveMultiValueStack = multiValueStack.primitiveStack(); final Object[] objectMultiValueStack = multiValueStack.objectStack(); - final byte[] resultTypes = functionType.resultTypes(); + final int[] resultTypes = functionType.resultTypes(); final int resultCount = resultTypes.length; assert primitiveMultiValueStack.length >= resultCount; assert objectMultiValueStack.length >= resultCount; @@ -202,25 +214,25 @@ private Object multiValueStackAsArray(WasmLanguage language) { } @ExplodeLoop - private static void popMultiValueResultUnroll(Object[] values, long[] primitiveMultiValueStack, Object[] objectMultiValueStack, byte[] resultTypes, int resultCount) { + private static void popMultiValueResultUnroll(Object[] values, long[] primitiveMultiValueStack, Object[] objectMultiValueStack, int[] resultTypes, int resultCount) { for (int i = 0; i < resultCount; i++) { values[i] = popMultiValueResult(primitiveMultiValueStack, objectMultiValueStack, resultTypes, i); } } - private static Object popMultiValueResult(long[] primitiveMultiValueStack, Object[] objectMultiValueStack, byte[] resultTypes, int i) { - final byte resultType = resultTypes[i]; + private static Object popMultiValueResult(long[] primitiveMultiValueStack, Object[] objectMultiValueStack, int[] resultTypes, int i) { + final int resultType = resultTypes[i]; return switch (resultType) { case WasmType.I32_TYPE -> (int) primitiveMultiValueStack[i]; case WasmType.I64_TYPE -> primitiveMultiValueStack[i]; case WasmType.F32_TYPE -> Float.intBitsToFloat((int) primitiveMultiValueStack[i]); case WasmType.F64_TYPE -> Double.longBitsToDouble(primitiveMultiValueStack[i]); - case WasmType.V128_TYPE, WasmType.FUNCREF_TYPE, WasmType.EXTERNREF_TYPE, WasmType.EXNREF_TYPE -> { + default -> { + assert resultType == WasmType.V128_TYPE || WasmType.isReferenceType(resultType); Object obj = objectMultiValueStack[i]; objectMultiValueStack[i] = null; yield obj; } - default -> throw WasmException.create(Failure.UNSPECIFIED_INTERNAL); }; } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/TableKind.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/TableKind.java index 3e565932388b..64f53c46a53f 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/TableKind.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/TableKind.java @@ -47,18 +47,18 @@ public enum TableKind { anyfunc(WasmType.FUNCREF_TYPE), exnref(WasmType.EXNREF_TYPE); - private final byte byteValue; + private final int value; - TableKind(byte byteValue) { - this.byteValue = byteValue; + TableKind(int value) { + this.value = value; } - public byte byteValue() { - return byteValue; + public int value() { + return value; } - public static String toString(byte byteValue) { - return switch (byteValue) { + public static String toString(int value) { + return switch (value) { case WasmType.EXTERNREF_TYPE -> "externref"; case WasmType.FUNCREF_TYPE -> "anyfunc"; case WasmType.EXNREF_TYPE -> "exnref"; diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ValueType.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ValueType.java index 61c707315525..076b1d5942e3 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ValueType.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ValueType.java @@ -56,13 +56,13 @@ public enum ValueType { externref(WasmType.EXTERNREF_TYPE), exnref(WasmType.EXNREF_TYPE); - private final byte byteValue; + private final int value; - ValueType(byte byteValue) { - this.byteValue = byteValue; + ValueType(int value) { + this.value = value; } - public static ValueType fromByteValue(byte value) { + public static ValueType fromValue(int value) { CompilerAsserts.neverPartOfCompilation(); return switch (value) { case WasmType.I32_TYPE -> i32; @@ -73,13 +73,12 @@ public static ValueType fromByteValue(byte value) { case WasmType.FUNCREF_TYPE -> anyfunc; case WasmType.EXTERNREF_TYPE -> externref; case WasmType.EXNREF_TYPE -> exnref; - default -> - throw WasmException.create(Failure.UNSPECIFIED_INTERNAL, null, "Unknown value type: 0x" + Integer.toHexString(value)); + default -> throw WasmException.create(Failure.UNSPECIFIED_INTERNAL, null, "Unknown value type: 0x" + Integer.toHexString(value)); }; } - public byte byteValue() { - return byteValue; + public int value() { + return value; } public static boolean isNumberType(ValueType valueType) { diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/Vector128Shape.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/Vector128Shape.java index 388bad7a939a..e9ff2fe3c71d 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/Vector128Shape.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/Vector128Shape.java @@ -53,15 +53,15 @@ public enum Vector128Shape { F32X4(WasmType.F32_TYPE, 4), F64X2(WasmType.F64_TYPE, 2); - private final byte unpackedType; + private final int unpackedType; private final int dimension; - Vector128Shape(byte unpackedType, int dimension) { + Vector128Shape(int unpackedType, int dimension) { this.unpackedType = unpackedType; this.dimension = dimension; } - public byte getUnpackedType() { + public int getUnpackedType() { return unpackedType; } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/WebAssembly.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/WebAssembly.java index b6d6916d920c..24221ff68fdd 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/WebAssembly.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/WebAssembly.java @@ -267,7 +267,7 @@ public static Sequence moduleExports(WasmModule module) } else if (f != null) { list.add(new ModuleExportDescriptor(name, ImportExportKind.function.name(), WebAssembly.functionTypeToString(f))); } else if (globalIndex != null) { - String valueType = ValueType.fromByteValue(module.globalValueType(globalIndex)).toString(); + String valueType = ValueType.fromValue(module.globalValueType(globalIndex)).toString(); String mutability = module.isGlobalMutable(globalIndex) ? "mut" : "con"; list.add(new ModuleExportDescriptor(name, ImportExportKind.global.name(), valueType + " " + mutability)); } else if (tagIndex != null) { @@ -319,7 +319,7 @@ public static List moduleImportsAsList(WasmModule module break; case ImportIdentifier.GLOBAL: final Integer globalIndex = importedGlobalDescriptors.get(descriptor); - String valueType = ValueType.fromByteValue(module.globalValueType(globalIndex)).toString(); + String valueType = ValueType.fromValue(module.globalValueType(globalIndex)).toString(); list.add(new ModuleImportDescriptor(descriptor.moduleName(), descriptor.memberName(), ImportExportKind.global.name(), valueType)); break; case ImportIdentifier.TAG: @@ -449,7 +449,7 @@ public WasmTable tableAlloc(int initial, int maximum, TableKind elemKind, Object throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Element type must be anyfunc. Enable wasm.BulkMemoryAndRefTypes to support other reference types"); } final int maxAllowedSize = minUnsigned(maximum, JS_LIMITS.tableInstanceSizeLimit()); - return new WasmTable(initial, maximum, maxAllowedSize, elemKind.byteValue(), initialValue); + return new WasmTable(initial, maximum, maxAllowedSize, elemKind.value(), initialValue); } private static Object tableGrow(Object[] args) { @@ -571,7 +571,7 @@ public static String functionTypeToString(WasmFunction f) { if (i != 0) { typeInfo.append(' '); } - typeInfo.append(ValueType.fromByteValue(f.paramTypeAt(i))); + typeInfo.append(ValueType.fromValue(f.paramTypeAt(i))); } typeInfo.append(')'); @@ -580,7 +580,7 @@ public static String functionTypeToString(WasmFunction f) { if (i != 0) { typeInfo.append(' '); } - typeInfo.append(ValueType.fromByteValue(f.resultTypeAt(i))); + typeInfo.append(ValueType.fromValue(f.resultTypeAt(i))); } return typeInfo.toString(); } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/collection/ByteArrayList.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/collection/ByteArrayList.java index 0ec417861399..f377ae6d2331 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/collection/ByteArrayList.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/collection/ByteArrayList.java @@ -40,9 +40,6 @@ */ package org.graalvm.wasm.collection; -import java.util.Arrays; -import java.util.NoSuchElementException; - public final class ByteArrayList { private static final byte[] EMPTY_BYTE_ARRAY = new byte[0]; @@ -60,20 +57,38 @@ public void add(byte b) { size++; } - public void push(byte b) { - add(b); - } - - public byte popBack() { - if (size() == 0) { - throw new NoSuchElementException("Cannot pop from an empty ByteArrayList."); + /** + * Add the little-endian LEB128 encoding of the signed {@code int} value to the list. + */ + public void addSignedInt32(int valueArg) { + int value = valueArg; + while (true) { + int b = value & 0x7f; + value >>= 7; + if ((value == 0 && (b & 0x40) == 0) || (value == -1 && (b & 0x40) != 0)) { + add((byte) b); + break; + } else { + add((byte) (b | 0x80)); + } } - size--; - return array[size]; } - public byte top() { - return array[size - 1]; + /** + * Add the little-endian LEB128 encoding of the unsigned {@code int} value to the list. + */ + public void addUnsignedInt32(int valueArg) { + int value = valueArg; + while (true) { + int b = value & 0x7f; + value >>>= 7; + if (value == 0) { + add((byte) b); + break; + } else { + add((byte) (b | 0x80)); + } + } } public void set(int index, byte b) { @@ -134,17 +149,4 @@ public byte[] toArray() { return EMPTY_BYTE_ARRAY; } } - - public static byte[] concat(ByteArrayList... byteArrayLists) { - int totalSize = Arrays.stream(byteArrayLists).mapToInt(ByteArrayList::size).sum(); - byte[] result = new byte[totalSize]; - int resultOffset = 0; - for (ByteArrayList byteArrayList : byteArrayLists) { - if (byteArrayList.array != null) { - System.arraycopy(byteArrayList.array, 0, result, resultOffset, byteArrayList.size); - resultOffset += byteArrayList.size; - } - } - return result; - } } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/collection/IntArrayList.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/collection/IntArrayList.java index 460c45fff701..d54eefdb745a 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/collection/IntArrayList.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/collection/IntArrayList.java @@ -40,6 +40,8 @@ */ package org.graalvm.wasm.collection; +import java.util.NoSuchElementException; + public final class IntArrayList { public static final int[] EMPTY_INT_ARRAY = new int[0]; @@ -57,11 +59,22 @@ public void add(int n) { offset++; } + public void push(int b) { + add(b); + } + public int popBack() { + if (size() == 0) { + throw new NoSuchElementException("Cannot pop from an empty IntArrayList."); + } offset--; return array[offset]; } + public int top() { + return array[offset - 1]; + } + public int get(int index) { return array[index]; } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/BytecodeBitEncoding.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/BytecodeBitEncoding.java index 76fc7ee6ac92..7260ed3cea3f 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/BytecodeBitEncoding.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/BytecodeBitEncoding.java @@ -137,9 +137,10 @@ public class BytecodeBitEncoding { public static final int ELEM_SEG_OFFSET_ADDRESS_U16 = 0b0000_0010; public static final int ELEM_SEG_OFFSET_ADDRESS_I32 = 0b0000_0011; - public static final int ELEM_SEG_TYPE_FUNREF = 0b0001_0000; - public static final int ELEM_SEG_TYPE_EXTERNREF = 0b0010_0000; - public static final int ELEM_SEG_TYPE_EXNREF = 0b0011_0000; + public static final int ELEM_SEG_TYPE_MASK = 0b0011_0000; + public static final int ELEM_SEG_TYPE_I8 = 0b0001_0000; + public static final int ELEM_SEG_TYPE_I16 = 0b0010_0000; + public static final int ELEM_SEG_TYPE_I32 = 0b0011_0000; public static final int ELEM_SEG_MODE_VALUE = 0b0000_1111; diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java index be1d4f83cd02..a5af133da711 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java @@ -582,14 +582,12 @@ public Object executeBodyFromOffset(WasmInstance instance, VirtualFrame frame, i throw WasmException.format(Failure.UNSPECIFIED_TRAP, this, "Unknown table element type: %s", element); } - int expectedTypeEquivalenceClass = symtab.equivalenceClass(expectedFunctionTypeIndex); - // Target function instance must be from the same context. assert functionInstanceContext == WasmContext.get(this); // Validate that the target function type matches the expected type of the // indirect call by performing an equivalence-class check. - if (expectedTypeEquivalenceClass != function.typeEquivalenceClass()) { + if (!symtab.closedTypeAt(expectedFunctionTypeIndex).matches(new SymbolTable.ClosedReferenceType(false, function.closedType()))) { enterErrorBranch(); failFunctionTypeCheck(function, expectedFunctionTypeIndex); } @@ -3359,7 +3357,7 @@ private void storeVectorLane(WasmMemory memory, WasmMemoryLibrary memoryLib, int // Checkstyle: stop method name check private void global_set(WasmInstance instance, VirtualFrame frame, int stackPointer, int index) { - final byte type = module.globalValueType(index); + final int type = module.globalValueType(index); CompilerAsserts.partialEvaluationConstant(type); // For global.set, we don't need to make sure that the referenced global is // mutable. @@ -3369,60 +3367,34 @@ private void global_set(WasmInstance instance, VirtualFrame frame, int stackPoin final GlobalRegistry globals = instance.globals(); switch (type) { - case WasmType.I32_TYPE: - globals.storeInt(globalAddress, popInt(frame, stackPointer)); - break; - case WasmType.F32_TYPE: - globals.storeFloat(globalAddress, popFloat(frame, stackPointer)); - break; - case WasmType.I64_TYPE: - globals.storeLong(globalAddress, popLong(frame, stackPointer)); - break; - case WasmType.F64_TYPE: - globals.storeDouble(globalAddress, popDouble(frame, stackPointer)); - break; - case WasmType.V128_TYPE: - globals.storeVector128(globalAddress, vector128Ops().toVector128(popVector128(frame, stackPointer))); - break; - case WasmType.FUNCREF_TYPE: - case WasmType.EXTERNREF_TYPE: - case WasmType.EXNREF_TYPE: + case WasmType.I32_TYPE -> globals.storeInt(globalAddress, popInt(frame, stackPointer)); + case WasmType.F32_TYPE -> globals.storeFloat(globalAddress, popFloat(frame, stackPointer)); + case WasmType.I64_TYPE -> globals.storeLong(globalAddress, popLong(frame, stackPointer)); + case WasmType.F64_TYPE -> globals.storeDouble(globalAddress, popDouble(frame, stackPointer)); + case WasmType.V128_TYPE -> globals.storeVector128(globalAddress, vector128Ops().toVector128(popVector128(frame, stackPointer))); + default -> { + assert WasmType.isReferenceType(type); globals.storeReference(globalAddress, popReference(frame, stackPointer)); - break; - default: - throw WasmException.create(Failure.UNSPECIFIED_TRAP, this, "Global variable cannot have the void type."); + } } } private void global_get(WasmInstance instance, VirtualFrame frame, int stackPointer, int index) { - final byte type = module.symbolTable().globalValueType(index); + final int type = module.symbolTable().globalValueType(index); CompilerAsserts.partialEvaluationConstant(type); final int globalAddress = module.symbolTable().globalAddress(index); final GlobalRegistry globals = instance.globals(); switch (type) { - case WasmType.I32_TYPE: - pushInt(frame, stackPointer, globals.loadAsInt(globalAddress)); - break; - case WasmType.F32_TYPE: - pushFloat(frame, stackPointer, globals.loadAsFloat(globalAddress)); - break; - case WasmType.I64_TYPE: - pushLong(frame, stackPointer, globals.loadAsLong(globalAddress)); - break; - case WasmType.F64_TYPE: - pushDouble(frame, stackPointer, globals.loadAsDouble(globalAddress)); - break; - case WasmType.V128_TYPE: - pushVector128(frame, stackPointer, vector128Ops().fromVector128(globals.loadAsVector128(globalAddress))); - break; - case WasmType.FUNCREF_TYPE: - case WasmType.EXTERNREF_TYPE: - case WasmType.EXNREF_TYPE: + case WasmType.I32_TYPE -> pushInt(frame, stackPointer, globals.loadAsInt(globalAddress)); + case WasmType.F32_TYPE -> pushFloat(frame, stackPointer, globals.loadAsFloat(globalAddress)); + case WasmType.I64_TYPE -> pushLong(frame, stackPointer, globals.loadAsLong(globalAddress)); + case WasmType.F64_TYPE -> pushDouble(frame, stackPointer, globals.loadAsDouble(globalAddress)); + case WasmType.V128_TYPE -> pushVector128(frame, stackPointer, vector128Ops().fromVector128(globals.loadAsVector128(globalAddress))); + default -> { + assert WasmType.isReferenceType(type); pushReference(frame, stackPointer, globals.loadAsReference(globalAddress)); - break; - default: - throw WasmException.create(Failure.UNSPECIFIED_TRAP, this, "Global variable cannot have the void type."); + } } } @@ -4547,7 +4519,7 @@ private Object[] createArgumentsForCall(VirtualFrame frame, int functionTypeInde int stackPointer = stackPointerOffset; for (int i = numArgs - 1; i >= 0; --i) { stackPointer--; - byte type = module.symbolTable().functionTypeParamTypeAt(functionTypeIndex, i); + int type = module.symbolTable().functionTypeParamTypeAt(functionTypeIndex, i); CompilerAsserts.partialEvaluationConstant(type); Object arg = switch (type) { case WasmType.I32_TYPE -> popInt(frame, stackPointer); @@ -4555,8 +4527,10 @@ private Object[] createArgumentsForCall(VirtualFrame frame, int functionTypeInde case WasmType.F32_TYPE -> popFloat(frame, stackPointer); case WasmType.F64_TYPE -> popDouble(frame, stackPointer); case WasmType.V128_TYPE -> vector128Ops().toVector128(popVector128(frame, stackPointer)); - case WasmType.FUNCREF_TYPE, WasmType.EXTERNREF_TYPE, WasmType.EXNREF_TYPE -> popReference(frame, stackPointer); - default -> throw WasmException.format(Failure.UNSPECIFIED_TRAP, this, "Unknown type: %d", type); + default -> { + assert WasmType.isReferenceType(type); + yield popReference(frame, stackPointer); + } }; WasmArguments.setArgument(args, i, arg); } @@ -4577,7 +4551,7 @@ private Object[] createFieldsForException(VirtualFrame frame, int functionTypeIn int stackPointer = stackPointerOffset; for (int i = numFields - 1; i >= 0; --i) { stackPointer--; - byte type = module.symbolTable().functionTypeParamTypeAt(functionTypeIndex, i); + int type = module.symbolTable().functionTypeParamTypeAt(functionTypeIndex, i); CompilerAsserts.partialEvaluationConstant(type); final Object arg = switch (type) { case WasmType.I32_TYPE -> popInt(frame, stackPointer); @@ -4585,8 +4559,10 @@ private Object[] createFieldsForException(VirtualFrame frame, int functionTypeIn case WasmType.F32_TYPE -> popFloat(frame, stackPointer); case WasmType.F64_TYPE -> popDouble(frame, stackPointer); case WasmType.V128_TYPE -> vector128Ops().toVector128(popVector128(frame, stackPointer)); - case WasmType.FUNCREF_TYPE, WasmType.EXTERNREF_TYPE, WasmType.EXNREF_TYPE -> popReference(frame, stackPointer); - default -> throw WasmException.format(Failure.UNSPECIFIED_TRAP, this, "Unknown type: %d", type); + default -> { + assert WasmType.isReferenceType(type); + yield popReference(frame, stackPointer); + } }; fields[i] = arg; } @@ -4601,7 +4577,7 @@ private void pushExceptionFields(VirtualFrame frame, WasmRuntimeException e, int final Object[] fields = e.fields(); int stackPointer = stackPointerOffset; for (int i = 0; i < numFields; i++) { - byte type = module.symbolTable().functionTypeParamTypeAt(functionTypeIndex, i); + int type = module.symbolTable().functionTypeParamTypeAt(functionTypeIndex, i); CompilerAsserts.partialEvaluationConstant(type); switch (type) { case WasmType.I32_TYPE -> pushInt(frame, stackPointer, (int) fields[i]); @@ -4609,8 +4585,10 @@ private void pushExceptionFields(VirtualFrame frame, WasmRuntimeException e, int case WasmType.F32_TYPE -> pushFloat(frame, stackPointer, (float) fields[i]); case WasmType.F64_TYPE -> pushDouble(frame, stackPointer, (double) fields[i]); case WasmType.V128_TYPE -> pushVector128(frame, stackPointer, vector128Ops().fromVector128((Vector128) fields[i])); - case WasmType.FUNCREF_TYPE, WasmType.EXTERNREF_TYPE, WasmType.EXNREF_TYPE -> pushReference(frame, stackPointer, fields[i]); - default -> throw WasmException.format(Failure.UNSPECIFIED_TRAP, this, "Unknown type: %d", type); + default -> { + assert WasmType.isReferenceType(type); + pushReference(frame, stackPointer, fields[i]); + } } stackPointer++; } @@ -4787,7 +4765,7 @@ private int pushDirectCallResult(VirtualFrame frame, int stackPointer, WasmFunct if (resultCount == 0) { return stackPointer; } else if (resultCount == 1) { - final byte resultType = function.resultTypeAt(0); + final int resultType = function.resultTypeAt(0); pushResult(frame, stackPointer, resultType, result); return stackPointer + 1; } else { @@ -4803,7 +4781,7 @@ private int pushIndirectCallResult(VirtualFrame frame, int stackPointer, int exp if (resultCount == 0) { return stackPointer; } else if (resultCount == 1) { - final byte resultType = module.symbolTable().functionTypeResultTypeAt(expectedFunctionTypeIndex, 0); + final int resultType = module.symbolTable().functionTypeResultTypeAt(expectedFunctionTypeIndex, 0); pushResult(frame, stackPointer, resultType, result); return stackPointer + 1; } else { @@ -4812,7 +4790,7 @@ private int pushIndirectCallResult(VirtualFrame frame, int stackPointer, int exp } } - private void pushResult(VirtualFrame frame, int stackPointer, byte resultType, Object result) { + private void pushResult(VirtualFrame frame, int stackPointer, int resultType, Object result) { CompilerAsserts.partialEvaluationConstant(resultType); switch (resultType) { case WasmType.I32_TYPE -> pushInt(frame, stackPointer, (int) result); @@ -4820,9 +4798,9 @@ private void pushResult(VirtualFrame frame, int stackPointer, byte resultType, O case WasmType.F32_TYPE -> pushFloat(frame, stackPointer, (float) result); case WasmType.F64_TYPE -> pushDouble(frame, stackPointer, (double) result); case WasmType.V128_TYPE -> pushVector128(frame, stackPointer, vector128Ops().fromVector128((Vector128) result)); - case WasmType.FUNCREF_TYPE, WasmType.EXTERNREF_TYPE, WasmType.EXNREF_TYPE -> pushReference(frame, stackPointer, result); default -> { - throw WasmException.format(Failure.UNSPECIFIED_TRAP, this, "Unknown result type: %d", resultType); + assert WasmType.isReferenceType(resultType); + pushReference(frame, stackPointer, result); } } } @@ -4845,7 +4823,7 @@ private void extractMultiValueResult(VirtualFrame frame, int stackPointer, Objec final long[] primitiveMultiValueStack = multiValueStack.primitiveStack(); final Object[] objectMultiValueStack = multiValueStack.objectStack(); for (int i = 0; i < resultCount; i++) { - final byte resultType = module.symbolTable().functionTypeResultTypeAt(functionTypeIndex, i); + final int resultType = module.symbolTable().functionTypeResultTypeAt(functionTypeIndex, i); CompilerAsserts.partialEvaluationConstant(resultType); switch (resultType) { case WasmType.I32_TYPE -> pushInt(frame, stackPointer + i, (int) primitiveMultiValueStack[i]); @@ -4856,14 +4834,11 @@ private void extractMultiValueResult(VirtualFrame frame, int stackPointer, Objec pushVector128(frame, stackPointer + i, vector128Ops().fromVector128((Vector128) objectMultiValueStack[i])); objectMultiValueStack[i] = null; } - case WasmType.FUNCREF_TYPE, WasmType.EXTERNREF_TYPE, WasmType.EXNREF_TYPE -> { + default -> { + assert WasmType.isReferenceType(resultType); pushReference(frame, stackPointer + i, objectMultiValueStack[i]); objectMultiValueStack[i] = null; } - default -> { - enterErrorBranch(); - throw WasmException.format(Failure.UNSPECIFIED_TRAP, this, "Unknown result type: %d", resultType); - } } } } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionRootNode.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionRootNode.java index d1a8bf480031..01000a040a76 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionRootNode.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionRootNode.java @@ -106,7 +106,7 @@ void enterErrorBranch() { codeEntry.errorBranch(); } - byte resultType(int index) { + int resultType(int index) { return codeEntry.resultType(index); } @@ -114,7 +114,7 @@ int paramCount() { return module().symbolTable().function(codeEntry.functionIndex()).paramCount(); } - byte localType(int index) { + int localType(int index) { return codeEntry.localType(index); } @@ -154,28 +154,19 @@ public Object executeWithInstance(VirtualFrame frame, WasmInstance instance) { if (resultCount == 0) { return WasmConstant.VOID; } else if (resultCount == 1) { - final byte resultType = resultType(0); + final int resultType = resultType(0); CompilerAsserts.partialEvaluationConstant(resultType); - switch (resultType) { - case WasmType.VOID_TYPE: - return WasmConstant.VOID; - case WasmType.I32_TYPE: - return popInt(frame, localCount); - case WasmType.I64_TYPE: - return popLong(frame, localCount); - case WasmType.F32_TYPE: - return popFloat(frame, localCount); - case WasmType.F64_TYPE: - return popDouble(frame, localCount); - case WasmType.V128_TYPE: - return Vector128Ops.SINGLETON_IMPLEMENTATION.toVector128(popVector128(frame, localCount)); - case WasmType.FUNCREF_TYPE: - case WasmType.EXTERNREF_TYPE: - case WasmType.EXNREF_TYPE: - return popReference(frame, localCount); - default: - throw WasmException.format(Failure.UNSPECIFIED_INTERNAL, this, "Unknown result type: %d", resultType); - } + return switch (resultType) { + case WasmType.I32_TYPE -> popInt(frame, localCount); + case WasmType.I64_TYPE -> popLong(frame, localCount); + case WasmType.F32_TYPE -> popFloat(frame, localCount); + case WasmType.F64_TYPE -> popDouble(frame, localCount); + case WasmType.V128_TYPE -> Vector128Ops.SINGLETON_IMPLEMENTATION.toVector128(popVector128(frame, localCount)); + default -> { + assert WasmType.isReferenceType(resultType); + yield popReference(frame, localCount); + } + }; } else { moveResultValuesToMultiValueStack(frame, resultCount, localCount); return WasmConstant.MULTI_VALUE; @@ -192,28 +183,15 @@ private void moveResultValuesToMultiValueStack(VirtualFrame frame, int resultCou final int resultType = resultType(i); CompilerAsserts.partialEvaluationConstant(resultType); switch (resultType) { - case WasmType.I32_TYPE: - primitiveMultiValueStack[i] = popInt(frame, localCount + i); - break; - case WasmType.I64_TYPE: - primitiveMultiValueStack[i] = popLong(frame, localCount + i); - break; - case WasmType.F32_TYPE: - primitiveMultiValueStack[i] = Float.floatToRawIntBits(popFloat(frame, localCount + i)); - break; - case WasmType.F64_TYPE: - primitiveMultiValueStack[i] = Double.doubleToRawLongBits(popDouble(frame, localCount + i)); - break; - case WasmType.V128_TYPE: - objectMultiValueStack[i] = Vector128Ops.SINGLETON_IMPLEMENTATION.toVector128(popVector128(frame, localCount + i)); - break; - case WasmType.FUNCREF_TYPE: - case WasmType.EXTERNREF_TYPE: - case WasmType.EXNREF_TYPE: + case WasmType.I32_TYPE -> primitiveMultiValueStack[i] = popInt(frame, localCount + i); + case WasmType.I64_TYPE -> primitiveMultiValueStack[i] = popLong(frame, localCount + i); + case WasmType.F32_TYPE -> primitiveMultiValueStack[i] = Float.floatToRawIntBits(popFloat(frame, localCount + i)); + case WasmType.F64_TYPE -> primitiveMultiValueStack[i] = Double.doubleToRawLongBits(popDouble(frame, localCount + i)); + case WasmType.V128_TYPE -> objectMultiValueStack[i] = Vector128Ops.SINGLETON_IMPLEMENTATION.toVector128(popVector128(frame, localCount + i)); + default -> { + assert WasmType.isReferenceType(resultType); objectMultiValueStack[i] = popReference(frame, localCount + i); - break; - default: - throw WasmException.format(Failure.UNSPECIFIED_INTERNAL, this, "Unknown result type: %d", resultType); + } } } } @@ -225,28 +203,17 @@ private void moveArgumentsToLocals(VirtualFrame frame) { assert WasmArguments.getArgumentCount(args) == paramCount : "Expected number of params " + paramCount + ", actual " + WasmArguments.getArgumentCount(args); for (int i = 0; i != paramCount; ++i) { final Object arg = WasmArguments.getArgument(args, i); - byte type = localType(i); + int type = localType(i); switch (type) { - case WasmType.I32_TYPE: - pushInt(frame, i, (int) arg); - break; - case WasmType.I64_TYPE: - pushLong(frame, i, (long) arg); - break; - case WasmType.F32_TYPE: - pushFloat(frame, i, (float) arg); - break; - case WasmType.F64_TYPE: - pushDouble(frame, i, (double) arg); - break; - case WasmType.V128_TYPE: - pushVector128(frame, i, Vector128Ops.SINGLETON_IMPLEMENTATION.fromVector128((Vector128) arg)); - break; - case WasmType.FUNCREF_TYPE: - case WasmType.EXTERNREF_TYPE: - case WasmType.EXNREF_TYPE: + case WasmType.I32_TYPE -> pushInt(frame, i, (int) arg); + case WasmType.I64_TYPE -> pushLong(frame, i, (long) arg); + case WasmType.F32_TYPE -> pushFloat(frame, i, (float) arg); + case WasmType.F64_TYPE -> pushDouble(frame, i, (double) arg); + case WasmType.V128_TYPE -> pushVector128(frame, i, Vector128Ops.SINGLETON_IMPLEMENTATION.fromVector128((Vector128) arg)); + default -> { + assert WasmType.isReferenceType(type); pushReference(frame, i, arg); - break; + } } } } @@ -255,28 +222,17 @@ private void moveArgumentsToLocals(VirtualFrame frame) { private void initializeLocals(VirtualFrame frame) { int paramCount = paramCount(); for (int i = paramCount; i != localCount(); ++i) { - byte type = localType(i); + int type = localType(i); switch (type) { - case WasmType.I32_TYPE: - pushInt(frame, i, 0); - break; - case WasmType.I64_TYPE: - pushLong(frame, i, 0L); - break; - case WasmType.F32_TYPE: - pushFloat(frame, i, 0F); - break; - case WasmType.F64_TYPE: - pushDouble(frame, i, 0D); - break; - case WasmType.V128_TYPE: - pushVector128(frame, i, Vector128Ops.SINGLETON_IMPLEMENTATION.fromVector128(Vector128.ZERO)); - break; - case WasmType.FUNCREF_TYPE: - case WasmType.EXTERNREF_TYPE: - case WasmType.EXNREF_TYPE: + case WasmType.I32_TYPE -> pushInt(frame, i, 0); + case WasmType.I64_TYPE -> pushLong(frame, i, 0L); + case WasmType.F32_TYPE -> pushFloat(frame, i, 0F); + case WasmType.F64_TYPE -> pushDouble(frame, i, 0D); + case WasmType.V128_TYPE -> pushVector128(frame, i, Vector128Ops.SINGLETON_IMPLEMENTATION.fromVector128(Vector128.ZERO)); + default -> { + WasmType.isReferenceType(type); pushReference(frame, i, WasmConstant.NULL); - break; + } } } } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/BytecodeParser.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/BytecodeParser.java index cafab53c9de7..fbe5115ce671 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/BytecodeParser.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/BytecodeParser.java @@ -56,7 +56,8 @@ import org.graalvm.wasm.WasmInstance; import org.graalvm.wasm.WasmModule; import org.graalvm.wasm.WasmStore; -import org.graalvm.wasm.collection.ByteArrayList; +import org.graalvm.wasm.WasmType; +import org.graalvm.wasm.collection.IntArrayList; import org.graalvm.wasm.constants.Bytecode; import org.graalvm.wasm.constants.BytecodeBitEncoding; import org.graalvm.wasm.constants.SegmentMode; @@ -226,11 +227,22 @@ public static void resetTableState(WasmStore store, WasmModule module, WasmInsta for (int i = 0; i < module.elemInstanceCount(); i++) { final int elemOffset = module.elemInstanceOffset(i); final int flags = bytecode[elemOffset]; - final int typeAndMode = bytecode[elemOffset + 1]; + final int typeLengthAndMode = bytecode[elemOffset + 1]; int effectiveOffset = elemOffset + 2; - final int elemMode = typeAndMode & BytecodeBitEncoding.ELEM_SEG_MODE_VALUE; + final int elemMode = typeLengthAndMode & BytecodeBitEncoding.ELEM_SEG_MODE_VALUE; + switch (typeLengthAndMode & BytecodeBitEncoding.ELEM_SEG_TYPE_MASK) { + case BytecodeBitEncoding.ELEM_SEG_TYPE_I8: + effectiveOffset++; + break; + case BytecodeBitEncoding.ELEM_SEG_TYPE_I16: + effectiveOffset += 2; + break; + case BytecodeBitEncoding.ELEM_SEG_TYPE_I32: + effectiveOffset += 4; + break; + } final int elemCount; switch (flags & BytecodeBitEncoding.ELEM_SEG_COUNT_MASK) { case BytecodeBitEncoding.ELEM_SEG_COUNT_U8: @@ -406,26 +418,26 @@ public static CodeEntry readCodeEntry(WasmModule module, byte[] bytecode, int co default: throw CompilerDirectives.shouldNotReachHere(); } - final byte[] locals; + final int[] locals; if ((flags & BytecodeBitEncoding.CODE_ENTRY_LOCALS_FLAG) != 0) { - ByteArrayList localsList = new ByteArrayList(); - for (; bytecode[effectiveOffset] != 0; effectiveOffset++) { - localsList.add(bytecode[effectiveOffset]); + IntArrayList localsList = new IntArrayList(); + for (; bytecode[effectiveOffset] != 0; effectiveOffset += 4) { + localsList.add(BinaryStreamParser.peek4(bytecode, effectiveOffset)); } effectiveOffset++; locals = localsList.toArray(); } else { - locals = Bytecode.EMPTY_BYTES; + locals = WasmType.VOID_TYPE_ARRAY; } - final byte[] results; + final int[] results; if ((flags & BytecodeBitEncoding.CODE_ENTRY_RESULT_FLAG) != 0) { - ByteArrayList resultsList = new ByteArrayList(); - for (; bytecode[effectiveOffset] != 0; effectiveOffset++) { - resultsList.add(bytecode[effectiveOffset]); + IntArrayList resultsList = new IntArrayList(); + for (; bytecode[effectiveOffset] != 0; effectiveOffset += 4) { + resultsList.add(BinaryStreamParser.peek4(bytecode, effectiveOffset)); } results = resultsList.toArray(); } else { - results = Bytecode.EMPTY_BYTES; + results = WasmType.VOID_TYPE_ARRAY; } final int endOffset; if (exceptionTableOffset == BytecodeBitEncoding.INVALID_EXCEPTION_TABLE_OFFSET) { diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/RuntimeBytecodeGen.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/RuntimeBytecodeGen.java index 6c000c36d1ca..36b17b0cb132 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/RuntimeBytecodeGen.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/RuntimeBytecodeGen.java @@ -47,8 +47,6 @@ import org.graalvm.wasm.constants.BytecodeBitEncoding; import org.graalvm.wasm.constants.SegmentMode; -import com.oracle.truffle.api.CompilerDirectives; - /** * A data structure for generating the GraalWasm runtime bytecode. */ @@ -74,6 +72,10 @@ private static boolean fitsIntoUnsignedByte(long value) { return Long.compareUnsigned(value, 255) <= 0; } + private static boolean fitsIntoSignedShort(int value) { + return value >= Short.MIN_VALUE && value <= Short.MAX_VALUE; + } + private static boolean fitsIntoUnsignedShort(int value) { return Integer.compareUnsigned(value, 65535) <= 0; } @@ -297,7 +299,7 @@ public void addExtendedMemoryInstruction(int opcode, int memoryIndex, long offse * @param resultCount The number of results of the block. * @param stackSize The stack size at the start of the block. * @param commonResultType The most common result type of the result types of the block. See - * {@link WasmType#getCommonValueType(byte[])}. + * {@link WasmType#getCommonValueType(int[])}. * @return The location of the label in the bytecode. */ public int addLabel(int resultCount, int stackSize, int commonResultType) { @@ -342,7 +344,7 @@ public int addLabel(int resultCount, int stackSize, int commonResultType) { * @param resultCount The number of results of the loop. * @param stackSize The stack size at the start of the loop. * @param commonResultType The most common result type of the result types of the loop. See - * {@link WasmType#getCommonValueType(byte[])}. + * {@link WasmType#getCommonValueType(int[])}. * @return The location of the loop label in the bytecode. */ public int addLoopLabel(int resultCount, int stackSize, int commonResultType) { @@ -692,20 +694,26 @@ public void addDataRuntimeHeader(int length) { * @param offsetAddress The offset address of the elem segment, -1 if missing * @return The location after the header in the bytecode */ - public int addElemHeader(int mode, int count, byte elemType, int tableIndex, byte[] offsetBytecode, int offsetAddress) { + public int addElemHeader(int mode, int count, int elemType, int tableIndex, byte[] offsetBytecode, int offsetAddress) { assert offsetBytecode == null || offsetAddress == -1 : "elem header does not allow offset bytecode and offset address"; assert mode == SegmentMode.ACTIVE || mode == SegmentMode.PASSIVE || mode == SegmentMode.DECLARATIVE : "invalid segment mode in elem header"; assert WasmType.isReferenceType(elemType) : "invalid elem type in elem header"; - int location = location(); + int flagsLocation = location(); + add1(0); + int modeLocation = location(); add1(0); - final int type = switch (elemType) { - case WasmType.FUNCREF_TYPE -> BytecodeBitEncoding.ELEM_SEG_TYPE_FUNREF; - case WasmType.EXTERNREF_TYPE -> BytecodeBitEncoding.ELEM_SEG_TYPE_EXTERNREF; - case WasmType.EXNREF_TYPE -> BytecodeBitEncoding.ELEM_SEG_TYPE_EXNREF; - default -> throw CompilerDirectives.shouldNotReachHere(); - }; - add1(type | mode); + int typeLengthAndMode = mode; + if (fitsIntoSignedByte(elemType)) { + typeLengthAndMode |= BytecodeBitEncoding.ELEM_SEG_TYPE_I8; + add1(elemType); + } else if (fitsIntoSignedShort(elemType)) { + typeLengthAndMode |= BytecodeBitEncoding.ELEM_SEG_TYPE_I16; + add2(elemType); + } else { + typeLengthAndMode |= BytecodeBitEncoding.ELEM_SEG_TYPE_I32; + add4(elemType); + } int flags = 0; if (fitsIntoUnsignedByte(count)) { flags |= BytecodeBitEncoding.ELEM_SEG_COUNT_U8; @@ -754,7 +762,8 @@ public int addElemHeader(int mode, int count, byte elemType, int tableIndex, byt add4(offsetAddress); } } - set(location, (byte) flags); + set(flagsLocation, (byte) flags); + set(modeLocation, (byte) typeLengthAndMode); return location(); } @@ -767,6 +776,15 @@ public void addByte(byte value) { add1(value); } + /** + * Adds a value type to the bytecode. + * + * @param type The value type that should be added + */ + public void addType(int type) { + add4(type); + } + /** * Adds a null entry to the data of an elem segment. */ diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/ir/CodeEntry.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/ir/CodeEntry.java index 2e7624f7d01c..4dff60c0b04c 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/ir/CodeEntry.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/ir/CodeEntry.java @@ -49,15 +49,15 @@ public final class CodeEntry { private final int functionIndex; private final int maxStackSize; - private final byte[] localTypes; - private final byte[] resultTypes; + private final int[] localTypes; + private final int[] resultTypes; private final List callNodes; private final int bytecodeStartOffset; private final int bytecodeEndOffset; private final boolean usesMemoryZero; private final int exceptionTableOffset; - public CodeEntry(int functionIndex, int maxStackSize, byte[] localTypes, byte[] resultTypes, List callNodes, int startOffset, int endOffset, boolean usesMemoryZero, + public CodeEntry(int functionIndex, int maxStackSize, int[] localTypes, int[] resultTypes, List callNodes, int startOffset, int endOffset, boolean usesMemoryZero, int exceptionTableOffset) { this.functionIndex = functionIndex; this.maxStackSize = maxStackSize; @@ -78,11 +78,11 @@ public int functionIndex() { return functionIndex; } - public byte[] localTypes() { + public int[] localTypes() { return localTypes; } - public byte[] resultTypes() { + public int[] resultTypes() { return resultTypes; } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/BlockFrame.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/BlockFrame.java index 3f164c40c9d9..b2848fc9c212 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/BlockFrame.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/BlockFrame.java @@ -55,14 +55,14 @@ class BlockFrame extends ControlFrame { private final IntArrayList branches; private final ArrayList exceptionHandlers; - BlockFrame(byte[] paramTypes, byte[] resultTypes, int initialStackSize, boolean unreachable) { + BlockFrame(int[] paramTypes, int[] resultTypes, int initialStackSize, boolean unreachable) { super(paramTypes, resultTypes, initialStackSize, unreachable); branches = new IntArrayList(); exceptionHandlers = new ArrayList<>(); } @Override - byte[] labelTypes() { + int[] labelTypes() { return resultTypes(); } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ControlFrame.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ControlFrame.java index 4545dffbb890..6d545b4fb961 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ControlFrame.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ControlFrame.java @@ -48,8 +48,8 @@ * Represents the scope of a block structure during module validation. */ public abstract class ControlFrame { - private final byte[] paramTypes; - private final byte[] resultTypes; + private final int[] paramTypes; + private final int[] resultTypes; private final int initialStackSize; private boolean unreachable; @@ -61,7 +61,7 @@ public abstract class ControlFrame { * @param initialStackSize The size of the value stack when entering this block structure. * @param unreachable If the block structure should be declared unreachable. */ - ControlFrame(byte[] paramTypes, byte[] resultTypes, int initialStackSize, boolean unreachable) { + ControlFrame(int[] paramTypes, int[] resultTypes, int initialStackSize, boolean unreachable) { this.paramTypes = paramTypes; this.resultTypes = resultTypes; this.initialStackSize = initialStackSize; @@ -69,11 +69,11 @@ public abstract class ControlFrame { commonResultType = WasmType.getCommonValueType(resultTypes); } - protected byte[] paramTypes() { + protected int[] paramTypes() { return paramTypes; } - public byte[] resultTypes() { + public int[] resultTypes() { return resultTypes; } @@ -91,7 +91,7 @@ protected int commonResultType() { /** * @return The types that must be on the value stack when branching to this frame. */ - abstract byte[] labelTypes(); + abstract int[] labelTypes(); protected int labelTypeLength() { return labelTypes().length; diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/IfFrame.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/IfFrame.java index fef91318d510..01968552042f 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/IfFrame.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/IfFrame.java @@ -59,7 +59,7 @@ class IfFrame extends ControlFrame { private int falseJumpLocation; private boolean elseBranch; - IfFrame(byte[] paramTypes, byte[] resultTypes, int initialStackSize, boolean unreachable, int falseJumpLocation) { + IfFrame(int[] paramTypes, int[] resultTypes, int initialStackSize, boolean unreachable, int falseJumpLocation) { super(paramTypes, resultTypes, initialStackSize, unreachable); branchTargets = new IntArrayList(); exceptionHandlers = new ArrayList<>(); @@ -68,7 +68,7 @@ class IfFrame extends ControlFrame { } @Override - byte[] labelTypes() { + int[] labelTypes() { return resultTypes(); } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/LoopFrame.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/LoopFrame.java index 71b69f9281db..9ce7ecb83a4e 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/LoopFrame.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/LoopFrame.java @@ -51,13 +51,13 @@ class LoopFrame extends ControlFrame { private final int labelLocation; - LoopFrame(byte[] paramTypes, byte[] resultTypes, int initialStackSize, boolean unreachable, int labelLocation) { + LoopFrame(int[] paramTypes, int[] resultTypes, int initialStackSize, boolean unreachable, int labelLocation) { super(paramTypes, resultTypes, initialStackSize, unreachable); this.labelLocation = labelLocation; } @Override - byte[] labelTypes() { + int[] labelTypes() { return paramTypes(); } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ParserState.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ParserState.java index f2d4d3f5b307..b29c9f142709 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ParserState.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ParserState.java @@ -48,7 +48,7 @@ import org.graalvm.wasm.Assert; import org.graalvm.wasm.WasmType; import org.graalvm.wasm.api.Vector128; -import org.graalvm.wasm.collection.ByteArrayList; +import org.graalvm.wasm.collection.IntArrayList; import org.graalvm.wasm.constants.Bytecode; import org.graalvm.wasm.exception.Failure; import org.graalvm.wasm.exception.WasmException; @@ -59,10 +59,10 @@ * additional information used to generate parser nodes. */ public class ParserState { - private static final byte[] EMPTY_ARRAY = new byte[0]; - private static final byte ANY = 0; + private static final int[] EMPTY_ARRAY = new int[0]; + private static final int ANY = 0; - private final ByteArrayList valueStack; + private final IntArrayList valueStack; private final ControlStack controlStack; private final RuntimeBytecodeGen bytecode; private final ArrayList exceptionTables; @@ -71,7 +71,7 @@ public class ParserState { private boolean usesMemoryZero; public ParserState(RuntimeBytecodeGen bytecode) { - this.valueStack = new ByteArrayList(); + this.valueStack = new IntArrayList(); this.controlStack = new ControlStack(); this.bytecode = bytecode; this.exceptionTables = new ArrayList<>(); @@ -85,7 +85,7 @@ public ParserState(RuntimeBytecodeGen bytecode) { * @param expectedValueType The expectedValueType used for error generation. * @return The top of the stack or -1. */ - private byte popInternal(byte expectedValueType) { + private int popInternal(int expectedValueType) { if (availableStackSize() == 0) { if (isCurrentStackUnreachable()) { return WasmType.UNKNOWN_TYPE; @@ -110,10 +110,10 @@ private byte popInternal(byte expectedValueType) { * @param expectedValueTypes Value types expected on the stack. * @return The maximum of available stack values smaller than the length of expectedValueTypes. */ - private byte[] popAvailableUnchecked(byte[] expectedValueTypes) { + private int[] popAvailableUnchecked(int[] expectedValueTypes) { int availableStackSize = availableStackSize(); int availableSize = Math.min(availableStackSize, expectedValueTypes.length); - byte[] popped = new byte[availableSize]; + int[] popped = new int[availableSize]; for (int i = availableSize - 1; i >= 0; i--) { popped[i] = popInternal(expectedValueTypes[i]); } @@ -125,9 +125,9 @@ private byte[] popAvailableUnchecked(byte[] expectedValueTypes) { * * @return The maximum of available stack values. */ - private byte[] popAvailableUnchecked() { + private int[] popAvailableUnchecked() { int availableStackSize = availableStackSize(); - byte[] popped = new byte[availableStackSize]; + int[] popped = new int[availableStackSize]; int j = 0; for (int i = availableStackSize - 1; i >= 0; i--) { popped[j] = popInternal(ANY); @@ -143,7 +143,7 @@ private byte[] popAvailableUnchecked() { * @param actualTypes The actual value types. * @return True if both are equivalent. */ - private boolean isTypeMismatch(byte[] expectedTypes, byte[] actualTypes) { + private boolean isTypeMismatch(int[] expectedTypes, int[] actualTypes) { if (expectedTypes.length != actualTypes.length) { return true; } @@ -163,8 +163,8 @@ private boolean isTypeMismatch(byte[] expectedTypes, byte[] actualTypes) { * * @param valueType The value type that should be added. */ - public void push(byte valueType) { - valueStack.push(valueType); + public void push(int valueType) { + valueStack.add(valueType); maxStackSize = Math.max(valueStack.size(), maxStackSize); } @@ -173,8 +173,8 @@ public void push(byte valueType) { * * @param valueTypes The value types that should be added. */ - public void pushAll(byte[] valueTypes) { - for (byte valueType : valueTypes) { + public void pushAll(int[] valueTypes) { + for (int valueType : valueTypes) { push(valueType); } } @@ -185,7 +185,7 @@ public void pushAll(byte[] valueTypes) { * @return The value type on top of the stack or -1. * @throws WasmException If the stack is empty. */ - public byte pop() { + public int pop() { return popInternal(ANY); } @@ -196,8 +196,8 @@ public byte pop() { * @return The value type on top of the stack. * @throws WasmException If the stack is empty or the value types do not match. */ - public byte popChecked(byte expectedValueType) { - final byte actualValueType = popInternal(expectedValueType); + public int popChecked(int expectedValueType) { + final int actualValueType = popInternal(expectedValueType); if (actualValueType != expectedValueType && actualValueType != WasmType.UNKNOWN_TYPE && expectedValueType != WasmType.UNKNOWN_TYPE) { throw ValidationErrors.createTypeMismatch(expectedValueType, actualValueType); } @@ -211,7 +211,7 @@ public byte popChecked(byte expectedValueType) { */ public void popReferenceTypeChecked() { if (availableStackSize() != 0) { - final byte value = valueStack.popBack(); + final int value = valueStack.popBack(); if (WasmType.isReferenceType(value)) { return; } @@ -230,8 +230,8 @@ public void popReferenceTypeChecked() { * @return The value types on top of the stack. * @throws WasmException If the stack is empty or the value types do not match. */ - public byte[] popAll(byte[] expectedValueTypes) { - byte[] popped = new byte[expectedValueTypes.length]; + public int[] popAll(int[] expectedValueTypes) { + int[] popped = new int[expectedValueTypes.length]; for (int i = expectedValueTypes.length - 1; i >= 0; i--) { popped[i] = popChecked(expectedValueTypes[i]); } @@ -264,7 +264,7 @@ private void unwindStack(int size) { } } - public void enterFunction(byte[] resultTypes) { + public void enterFunction(int[] resultTypes) { enterBlock(EMPTY_ARRAY, resultTypes); } @@ -275,7 +275,7 @@ public void enterFunction(byte[] resultTypes) { * @param paramTypes The param types of the block that was entered. * @param resultTypes The result types of the block that was entered. */ - public void enterBlock(byte[] paramTypes, byte[] resultTypes) { + public void enterBlock(int[] paramTypes, int[] resultTypes) { ControlFrame frame = new BlockFrame(paramTypes, resultTypes, valueStack.size(), false); controlStack.push(frame); pushAll(paramTypes); @@ -288,7 +288,7 @@ public void enterBlock(byte[] paramTypes, byte[] resultTypes) { * @param paramTypes The param types of the loop that was entered. * @param resultTypes The result types of the loop that was entered. */ - public void enterLoop(byte[] paramTypes, byte[] resultTypes) { + public void enterLoop(int[] paramTypes, int[] resultTypes) { final int label = bytecode.addLoopLabel(paramTypes.length, valueStack.size(), WasmType.getCommonValueType(resultTypes)); ControlFrame frame = new LoopFrame(paramTypes, resultTypes, valueStack.size(), false, label); controlStack.push(frame); @@ -302,7 +302,7 @@ public void enterLoop(byte[] paramTypes, byte[] resultTypes) { * @param paramTypes The param types of the if and else branch that was entered. * @param resultTypes The result type of the if and else branch that was entered. */ - public void enterIf(byte[] paramTypes, byte[] resultTypes) { + public void enterIf(int[] paramTypes, int[] resultTypes) { final int fixupLocation = bytecode.addIfLocation(); ControlFrame frame = new IfFrame(paramTypes, resultTypes, valueStack.size(), false, fixupLocation); controlStack.push(frame); @@ -326,7 +326,7 @@ public void enterElse() { * @param resultTypes The result types of the try table that was entered. * @param handlers The exception handlers of the try table that was entered. */ - public void enterTryTable(byte[] paramTypes, byte[] resultTypes, ExceptionHandler[] handlers) { + public void enterTryTable(int[] paramTypes, int[] resultTypes, ExceptionHandler[] handlers) { final TryTableFrame frame = new TryTableFrame(paramTypes, resultTypes, valueStack.size(), false, bytecode.location(), handlers); controlStack.push(frame); @@ -396,7 +396,7 @@ public void addSelectInstruction(int instruction) { public void addConditionalBranch(int branchLabel) { checkLabelExists(branchLabel); ControlFrame frame = getFrame(branchLabel); - final byte[] labelTypes = frame.labelTypes(); + final int[] labelTypes = frame.labelTypes(); popAll(labelTypes); pushAll(labelTypes); frame.addBranchIf(bytecode); @@ -411,7 +411,7 @@ public void addConditionalBranch(int branchLabel) { public void addUnconditionalBranch(int branchLabel) { checkLabelExists(branchLabel); ControlFrame frame = getFrame(branchLabel); - final byte[] labelTypes = frame.labelTypes(); + final int[] labelTypes = frame.labelTypes(); popAll(labelTypes); frame.addBranch(bytecode); } @@ -427,11 +427,11 @@ public void addBranchTable(int[] branchLabels) { int branchLabel = branchLabels[branchLabels.length - 1]; checkLabelExists(branchLabel); ControlFrame frame = getFrame(branchLabel); - byte[] branchLabelReturnTypes = frame.labelTypes(); + int[] branchLabelReturnTypes = frame.labelTypes(); for (int otherBranchLabel : branchLabels) { checkLabelExists(otherBranchLabel); frame = getFrame(otherBranchLabel); - byte[] otherBranchLabelReturnTypes = frame.labelTypes(); + int[] otherBranchLabelReturnTypes = frame.labelTypes(); checkLabelTypes(branchLabelReturnTypes, otherBranchLabelReturnTypes); pushAll(popAll(otherBranchLabelReturnTypes)); frame.addBranchTableItem(bytecode); @@ -633,10 +633,10 @@ public void addVectorLaneInstruction(int instruction, byte laneIndex) { * * @return The types of the return values of the current frame. */ - public byte[] exit(boolean multiValue) { + public int[] exit(boolean multiValue) { Assert.assertTrue(!controlStack.isEmpty(), Failure.UNEXPECTED_END_OF_BLOCK); ControlFrame frame = controlStack.peek(); - byte[] resultTypes = frame.resultTypes(); + int[] resultTypes = frame.resultTypes(); frame.exit(bytecode); checkStackAfterFrameExit(frame, resultTypes); @@ -653,9 +653,9 @@ public byte[] exit(boolean multiValue) { * @param frame The frame that is exited. * @param resultTypes The expected return types of the frame. */ - void checkStackAfterFrameExit(ControlFrame frame, byte[] resultTypes) { + void checkStackAfterFrameExit(ControlFrame frame, int[] resultTypes) { if (availableStackSize() > resultTypes.length) { - byte[] actualTypes = popAvailableUnchecked(); + int[] actualTypes = popAvailableUnchecked(); if (isTypeMismatch(resultTypes, actualTypes)) { throw ValidationErrors.createResultTypesMismatch(resultTypes, actualTypes); } @@ -687,11 +687,11 @@ public ControlFrame getRootBlock() { * stack. */ private void checkResultTypes(ControlFrame frame) { - byte[] resultTypes = frame.resultTypes(); + int[] resultTypes = frame.resultTypes(); if (isCurrentStackUnreachable()) { popAll(resultTypes); } else { - byte[] actualTypes = popAvailableUnchecked(resultTypes); + int[] actualTypes = popAvailableUnchecked(resultTypes); if (isTypeMismatch(resultTypes, actualTypes)) { throw ValidationErrors.createResultTypesMismatch(resultTypes, actualTypes); } @@ -705,11 +705,11 @@ private void checkResultTypes(ControlFrame frame) { * @throws WasmException If the parameter value types and the vale types on the stack do not * match. */ - public void checkParamTypes(byte[] paramTypes) { + public void checkParamTypes(int[] paramTypes) { if (isCurrentStackUnreachable()) { popAll(paramTypes); } else { - byte[] actualTypes = popAvailableUnchecked(paramTypes); + int[] actualTypes = popAvailableUnchecked(paramTypes); if (isTypeMismatch(paramTypes, actualTypes)) { throw ValidationErrors.createParamTypesMismatch(paramTypes, actualTypes); } @@ -735,7 +735,7 @@ public void checkLabelExists(int label) { * @param actualTypes The value types that should be checked. * @throws WasmException If the provided sets of value types do not match. */ - public void checkLabelTypes(byte[] expectedTypes, byte[] actualTypes) { + public void checkLabelTypes(int[] expectedTypes, int[] actualTypes) { if (isTypeMismatch(expectedTypes, actualTypes)) { throw ValidationErrors.createLabelTypesMismatch(expectedTypes, actualTypes); } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/TryTableFrame.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/TryTableFrame.java index 271b33f0c4be..90c34e5cd050 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/TryTableFrame.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/TryTableFrame.java @@ -49,7 +49,7 @@ public class TryTableFrame extends BlockFrame { private final ExceptionTable table; - TryTableFrame(byte[] paramTypes, byte[] resultTypes, int initialStackSize, boolean unreachable, int startOffset, ExceptionHandler[] handlers) { + TryTableFrame(int[] paramTypes, int[] resultTypes, int initialStackSize, boolean unreachable, int startOffset, ExceptionHandler[] handlers) { super(paramTypes, resultTypes, initialStackSize, unreachable); this.table = new ExceptionTable(startOffset, handlers); } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ValidationErrors.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ValidationErrors.java index 3e55f0dfec8c..ad62fd3d73c9 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ValidationErrors.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ValidationErrors.java @@ -50,25 +50,10 @@ import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; public class ValidationErrors { - private static String getValueTypeString(byte valueType) { - return switch (valueType) { - case WasmType.VOID_TYPE -> ""; - case WasmType.I32_TYPE -> "i32"; - case WasmType.I64_TYPE -> "i64"; - case WasmType.F32_TYPE -> "f32"; - case WasmType.F64_TYPE -> "f64"; - case WasmType.V128_TYPE -> "v128"; - case WasmType.FUNCREF_TYPE -> "funcref"; - case WasmType.EXTERNREF_TYPE -> "externref"; - case WasmType.EXNREF_TYPE -> "exnref"; - default -> "unknown"; - }; - } - - private static String getValueTypesString(byte[] valueTypes) { + private static String getValueTypesString(int[] valueTypes) { StringJoiner stringJoiner = new StringJoiner(","); - for (byte valueType : valueTypes) { - stringJoiner.add(getValueTypeString(valueType)); + for (int valueType : valueTypes) { + stringJoiner.add(WasmType.toString(valueType)); } return stringJoiner.toString(); } @@ -78,28 +63,28 @@ private static WasmException create(String message, Object expected, Object actu } @TruffleBoundary - public static WasmException createTypeMismatch(byte expectedType, byte actualType) { - String expectedTypeString = getValueTypeString(expectedType); - String actualTypeString = getValueTypeString(actualType); + public static WasmException createTypeMismatch(int expectedType, int actualType) { + String expectedTypeString = WasmType.toString(expectedType); + String actualTypeString = WasmType.toString(actualType); return create("Expected type [%s], but got [%s].", expectedTypeString, actualTypeString); } @TruffleBoundary - public static WasmException createResultTypesMismatch(byte[] expectedTypes, byte[] actualTypes) { + public static WasmException createResultTypesMismatch(int[] expectedTypes, int[] actualTypes) { String expectedTypesString = getValueTypesString(expectedTypes); String actualTypesString = getValueTypesString(actualTypes); return create("Expected result types [%s], but got [%s].", expectedTypesString, actualTypesString); } @TruffleBoundary - public static WasmException createLabelTypesMismatch(byte[] expectedTypes, byte[] actualTypes) { + public static WasmException createLabelTypesMismatch(int[] expectedTypes, int[] actualTypes) { String expectedTypesString = getValueTypesString(expectedTypes); String actualTypesString = getValueTypesString(actualTypes); return create("Inconsistent label types. Expected [%s], but got [%s].", expectedTypesString, actualTypesString); } @TruffleBoundary - public static WasmException createParamTypesMismatch(byte[] expectedTypes, byte[] actualTypes) { + public static WasmException createParamTypesMismatch(int[] expectedTypes, int[] actualTypes) { String expectedTypesString = getValueTypesString(expectedTypes); String actualTypesString = getValueTypesString(actualTypes); return create("Expected param types [%s], but got [%s].", expectedTypesString, actualTypesString); @@ -121,8 +106,8 @@ public static WasmException createExpectedAnyOnEmptyStack() { } @TruffleBoundary - public static WasmException createExpectedTypeOnEmptyStack(byte expectedType) { - String expectedTypeString = getValueTypeString(expectedType); + public static WasmException createExpectedTypeOnEmptyStack(int expectedType) { + String expectedTypeString = WasmType.toString(expectedType); return create("Expected type [%s], but got [].", expectedTypeString, ""); } } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/predefined/BuiltinModule.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/predefined/BuiltinModule.java index d5390e6a9521..da2d53d956c0 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/predefined/BuiltinModule.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/predefined/BuiltinModule.java @@ -97,7 +97,7 @@ public WasmInstance createInstance(WasmLanguage language, WasmStore store, Strin return instance; } - protected WasmFunction defineFunction(WasmContext context, WasmModule module, String name, byte[] paramTypes, byte[] retTypes, WasmRootNode rootNode) { + protected WasmFunction defineFunction(WasmContext context, WasmModule module, String name, int[] paramTypes, int[] retTypes, WasmRootNode rootNode) { // Must instantiate RootNode in the right language / sharing layer. assert context.language() == rootNode.getLanguage(WasmLanguage.class); // We could check if the same function type had already been allocated, @@ -109,24 +109,18 @@ protected WasmFunction defineFunction(WasmContext context, WasmModule module, St return function; } - protected int defineGlobal(WasmModule module, String name, byte valueType, byte mutability, Object value) { + protected int defineGlobal(WasmModule module, String name, int valueType, byte mutability, Object value) { int index = module.symbolTable().numGlobals(); module.symbolTable().declareExportedGlobalWithValue(name, index, valueType, mutability, value); return index; } - protected int defineTable(WasmContext context, WasmModule module, String tableName, int initSize, int maxSize, byte type) { + protected int defineTable(WasmContext context, WasmModule module, String tableName, int initSize, int maxSize, int type) { final boolean referenceTypes = context.getContextOptions().supportBulkMemoryAndRefTypes(); - switch (type) { - case WasmType.FUNCREF_TYPE: - break; - case WasmType.EXTERNREF_TYPE: - if (!referenceTypes) { - throw WasmException.create(Failure.UNSPECIFIED_MALFORMED, "Only function types are currently supported in tables."); - } - break; - default: - throw WasmException.create(Failure.MALFORMED_REFERENCE_TYPE, "Only reference types supported in tables."); + if (!WasmType.isReferenceType(type)) { + throw WasmException.create(Failure.MALFORMED_REFERENCE_TYPE, "Only reference types supported in tables."); + } else if (!referenceTypes && type != WasmType.FUNCREF_TYPE) { + throw WasmException.create(Failure.UNSPECIFIED_MALFORMED, "Only function types are currently supported in tables."); } int index = module.symbolTable().tableCount(); module.symbolTable().allocateTable(index, initSize, maxSize, type, referenceTypes); @@ -143,7 +137,7 @@ protected void defineMemory(WasmContext context, WasmModule module, String memor module.symbolTable().exportMemory(index, memoryName); } - protected void importFunction(WasmContext context, WasmModule module, String importModuleName, String importFunctionName, byte[] paramTypes, byte[] retTypes, String exportName) { + protected void importFunction(WasmContext context, WasmModule module, String importModuleName, String importFunctionName, int[] paramTypes, int[] retTypes, String exportName) { final int typeIdx = module.symbolTable().allocateFunctionType(paramTypes, retTypes, context.getContextOptions().supportMultiValue()); final WasmFunction function = module.symbolTable().importFunction(importModuleName, importFunctionName, typeIdx); module.symbolTable().exportFunction(function.index(), exportName); @@ -155,7 +149,7 @@ protected void importMemory(WasmContext context, WasmModule module, String impor module.symbolTable().importMemory(importModuleName, memoryName, index, initSize, maxSize, is64Bit, isShared, multiMemory); } - protected static byte[] types(byte... args) { + protected static int[] types(int... args) { return args; } } From 039f950f8b318cb7052771a287aaf0c8b73168b7 Mon Sep 17 00:00:00 2001 From: Jirka Marsik Date: Tue, 23 Sep 2025 12:33:13 +0200 Subject: [PATCH 02/40] Replace FunctionType with ClosedFunctionType --- .../src/org/graalvm/wasm/Assert.java | 6 - .../src/org/graalvm/wasm/BinaryParser.java | 8 +- .../src/org/graalvm/wasm/Linker.java | 9 +- .../src/org/graalvm/wasm/SymbolTable.java | 171 ++++++++---------- .../src/org/graalvm/wasm/WasmFunction.java | 14 +- .../org/graalvm/wasm/WasmInstantiator.java | 2 +- .../src/org/graalvm/wasm/WasmLanguage.java | 4 +- .../src/org/graalvm/wasm/WasmTag.java | 6 +- .../src/org/graalvm/wasm/api/FuncType.java | 22 +-- .../wasm/api/InteropCallAdapterNode.java | 133 ++++++++------ .../src/org/graalvm/wasm/api/ValueType.java | 44 ++++- .../src/org/graalvm/wasm/api/WebAssembly.java | 6 +- .../org/graalvm/wasm/globals/WasmGlobal.java | 9 +- 13 files changed, 229 insertions(+), 205 deletions(-) diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Assert.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Assert.java index 55f94fc23eed..9d8e5aaa8fe5 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Assert.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Assert.java @@ -199,12 +199,6 @@ public static void assertTrue(boolean condition, String message, Failure failure } } - public static void assertFunctionTypeEquals(SymbolTable.FunctionType t1, SymbolTable.FunctionType t2, Failure failure) throws WasmException { - if (!t1.equals(t2)) { - fail(failure, "%s: %s should = %s", failure.name, t1, t2); - } - } - @TruffleBoundary public static RuntimeException fail(Failure failure, String format, Object... args) throws WasmException { throw WasmException.format(failure, format, args); diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java index 71664b4c55cf..337523bc7d97 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java @@ -783,7 +783,7 @@ private CodeEntry readFunction(int functionIndex, int[] locals, int[] resultType if (!multiValue) { assertIntLessOrEqual(function.resultCount(), 1, Failure.INVALID_RESULT_ARITY); } - state.pushAll(function.type().resultTypes()); + state.pushAll(function.resultTypes()); state.addCall(callNodes.size(), callFunctionIndex); callNodes.add(new CallNode(bytecode.location(), callFunctionIndex)); break; @@ -885,7 +885,7 @@ private CodeEntry readFunction(int functionIndex, int[] locals, int[] resultType checkExceptionHandlingSupport(opcode); final int tagIndex = readTagIndex(); final int typeIndex = module.tagTypeIndex(tagIndex); - final int[] paramTypes = module.typeAt(typeIndex).paramTypes(); + final int[] paramTypes = module.functionTypeParamTypesAsArray(typeIndex); state.popAll(paramTypes); state.addMiscFlag(); state.addInstruction(Bytecode.THROW, tagIndex); @@ -2542,7 +2542,7 @@ private ExceptionHandler[] readExceptionHandlers(ParserState state) { final int label = readUnsignedInt32(); assertUnsignedIntLess(label, state.controlStackSize(), Failure.INVALID_CATCH_CLAUSE_LABEL); final int typeIndex = module.tagTypeIndex(tag); - final int[] paramTypes = module.typeAt(typeIndex).paramTypes(); + final int[] paramTypes = module.functionTypeParamTypesAsArray(typeIndex); handlers[i] = state.enterCatchClause(opcode, tag, label); state.pushAll(paramTypes); } @@ -2551,7 +2551,7 @@ private ExceptionHandler[] readExceptionHandlers(ParserState state) { final int label = readUnsignedInt32(); assertUnsignedIntLess(label, state.controlStackSize(), Failure.INVALID_CATCH_CLAUSE_LABEL); final int typeIndex = module.tagTypeIndex(tag); - final int[] paramTypes = module.typeAt(typeIndex).paramTypes(); + final int[] paramTypes = module.functionTypeParamTypesAsArray(typeIndex); handlers[i] = state.enterCatchClause(opcode, tag, label); state.pushAll(paramTypes); state.push(EXNREF_TYPE); diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java index 4c9049a51879..997d923a6e74 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java @@ -40,7 +40,6 @@ */ package org.graalvm.wasm; -import static org.graalvm.wasm.Assert.assertFunctionTypeEquals; import static org.graalvm.wasm.Assert.assertIntEqual; import static org.graalvm.wasm.Assert.assertTrue; import static org.graalvm.wasm.Assert.assertUnsignedIntGreaterOrEqual; @@ -402,7 +401,7 @@ void resolveFunctionImport(WasmStore store, WasmInstance instance, WasmFunction Object externalFunctionInstance = lookupImportObject(instance, importDescriptor, imports, Object.class); if (externalFunctionInstance != null) { if (externalFunctionInstance instanceof WasmFunctionInstance functionInstance) { - if (!function.type().equals(functionInstance.function().type())) { + if (!function.closedType().matches(functionInstance.function().closedType())) { throw WasmException.create(Failure.INCOMPATIBLE_IMPORT_TYPE); } instance.setTarget(function.index(), functionInstance.target()); @@ -435,7 +434,7 @@ void resolveFunctionImport(WasmStore store, WasmInstance instance, WasmFunction throw WasmException.create(Failure.UNKNOWN_IMPORT, "The imported function '" + function.importedFunctionName() + "', referenced in the module '" + instance.name() + "', does not exist in the imported module '" + function.importedModuleName() + "'."); } - if (!function.type().equals(importedFunction.type())) { + if (!function.closedType().matches(importedFunction.closedType())) { throw WasmException.create(Failure.INCOMPATIBLE_IMPORT_TYPE); } final CallTarget target = importedInstance.target(importedFunction.index()); @@ -511,7 +510,7 @@ void resolveMemoryExport(WasmInstance instance, int memoryIndex, String exported }); } - void resolveTagImport(WasmStore store, WasmInstance instance, ImportDescriptor importDescriptor, int tagIndex, SymbolTable.FunctionType type, ImportValueSupplier imports) { + void resolveTagImport(WasmStore store, WasmInstance instance, ImportDescriptor importDescriptor, int tagIndex, SymbolTable.ClosedFunctionType type, ImportValueSupplier imports) { final String importedModuleName = importDescriptor.moduleName(); final String importedTagName = importDescriptor.memberName(); final Runnable resolveAction = () -> { @@ -539,7 +538,7 @@ void resolveTagImport(WasmStore store, WasmInstance instance, ImportDescriptor i } importedTag = importedInstance.tag(exportedTagIndex); } - assertFunctionTypeEquals(type, importedTag.type(), Failure.INCOMPATIBLE_IMPORT_TYPE); + Assert.assertTrue(type.matches(importedTag.type()), Failure.INCOMPATIBLE_IMPORT_TYPE); instance.setTag(tagIndex, importedTag); }; resolutionDag.resolveLater(new ImportTagSym(instance.name(), importDescriptor, tagIndex), new Sym[]{new ExportTagSym(importedModuleName, importedTagName)}, resolveAction); diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java index e2a6a9bbf3b6..8042bcf1f075 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java @@ -47,7 +47,6 @@ import static org.graalvm.wasm.WasmMath.minUnsigned; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import org.graalvm.collections.EconomicMap; @@ -85,11 +84,28 @@ public abstract class SymbolTable { public static final int UNINITIALIZED_ADDRESS = Integer.MIN_VALUE; public abstract static sealed class ClosedValueType { + // This is a workaround until we can use pattern matching in JDK 21+. + public enum Kind { + Number, + Vector, + Reference + } + public abstract boolean matches(ClosedValueType valueSubType); + + public abstract Kind kind(); } public abstract static sealed class ClosedHeapType { + // This is a workaround until we can use pattern matching in JDK 21+. + public enum Kind { + Abstract, + Function + } + public abstract boolean matches(ClosedHeapType heapSubType); + + public abstract Kind kind(); } public static final class NumberType extends ClosedValueType { @@ -104,10 +120,19 @@ private NumberType(int value) { this.value = value; } + public int value() { + return value; + } + @Override public boolean matches(ClosedValueType valueSubType) { return this == valueSubType; } + + @Override + public Kind kind() { + return Kind.Number; + } } public static final class VectorType extends ClosedValueType { @@ -119,10 +144,19 @@ private VectorType(int value) { this.value = value; } + public int value() { + return value; + } + @Override public boolean matches(ClosedValueType valueSubType) { return this == valueSubType; } + + @Override + public Kind kind() { + return Kind.Vector; + } } public static final class ClosedReferenceType extends ClosedValueType { @@ -130,6 +164,8 @@ public static final class ClosedReferenceType extends ClosedValueType { public static final ClosedReferenceType NONNULL_FUNCREF = new ClosedReferenceType(false, AbstractHeapType.FUNC); public static final ClosedReferenceType EXTERNREF = new ClosedReferenceType(true, AbstractHeapType.EXTERN); public static final ClosedReferenceType NONNULL_EXTERNREF = new ClosedReferenceType(false, AbstractHeapType.EXTERN); + public static final ClosedReferenceType EXNREF = new ClosedReferenceType(true, AbstractHeapType.EXN); + public static final ClosedReferenceType NONNULL_EXNREF = new ClosedReferenceType(false, AbstractHeapType.EXN); private final boolean nullable; private final ClosedHeapType closedHeapType; @@ -139,15 +175,29 @@ public ClosedReferenceType(boolean nullable, ClosedHeapType closedHeapType) { this.closedHeapType = closedHeapType; } + public boolean nullable() { + return nullable; + } + + public ClosedHeapType heapType() { + return closedHeapType; + } + @Override public boolean matches(ClosedValueType valueSubType) { return valueSubType instanceof ClosedReferenceType referenceSubType && (!referenceSubType.nullable || this.nullable) && this.closedHeapType.matches(referenceSubType.closedHeapType); } + + @Override + public Kind kind() { + return Kind.Reference; + } } public static final class AbstractHeapType extends ClosedHeapType { public static final AbstractHeapType FUNC = new AbstractHeapType(WasmType.FUNC_HEAPTYPE); public static final AbstractHeapType EXTERN = new AbstractHeapType(WasmType.EXTERN_HEAPTYPE); + public static final AbstractHeapType EXN = new AbstractHeapType(WasmType.EXN_HEAPTYPE); private final int value; @@ -155,112 +205,61 @@ private AbstractHeapType(int value) { this.value = value; } + public int value() { + return value; + } + @Override public boolean matches(ClosedHeapType heapSubType) { return switch (this.value) { case WasmType.FUNC_HEAPTYPE -> this == FUNC || heapSubType instanceof ClosedFunctionType; case WasmType.EXTERN_HEAPTYPE -> this == EXTERN; + case WasmType.EXN_HEAPTYPE -> this == EXN; default -> throw CompilerDirectives.shouldNotReachHere(); }; } - } - - public static final class ClosedFunctionType extends ClosedHeapType { - @CompilationFinal(dimensions = 1) private final ClosedValueType[] paramTypes; - @CompilationFinal(dimensions = 1) private final ClosedValueType[] resultTypes; - - ClosedFunctionType(ClosedValueType[] paramTypes, ClosedValueType[] resultTypes) { - this.paramTypes = paramTypes; - this.resultTypes = resultTypes; - } @Override - public boolean matches(ClosedHeapType heapSubType) { - if (!(heapSubType instanceof ClosedFunctionType functionSubType)) { - return false; - } - if (this.paramTypes.length != functionSubType.paramTypes.length) { - return false; - } - for (int i = 0; i < this.paramTypes.length; i++) { - if (!functionSubType.paramTypes[i].matches(this.paramTypes[i])) { - return false; - } - } - if (this.resultTypes.length != functionSubType.resultTypes.length) { - return false; - } - for (int i = 0; i < this.resultTypes.length; i++) { - if (!this.resultTypes[i].matches(functionSubType.resultTypes[i])) { - return false; - } - } - return true; + public Kind kind() { + return Kind.Abstract; } } - public static final class FunctionType { - @CompilationFinal(dimensions = 1) private final int[] paramTypes; - @CompilationFinal(dimensions = 1) private final int[] resultTypes; - private final SymbolTable symbolTable; - private final int hashCode; + public static final class ClosedFunctionType extends ClosedHeapType { + @CompilationFinal(dimensions = 1) private final ClosedValueType[] paramTypes; + @CompilationFinal(dimensions = 1) private final ClosedValueType[] resultTypes; - FunctionType(int[] paramTypes, int[] resultTypes, SymbolTable symbolTable) { + public ClosedFunctionType(ClosedValueType[] paramTypes, ClosedValueType[] resultTypes) { this.paramTypes = paramTypes; this.resultTypes = resultTypes; - this.symbolTable = symbolTable; - this.hashCode = Arrays.hashCode(paramTypes) ^ Arrays.hashCode(resultTypes) ^ Arrays.hashCode(symbolTable.typeData); } - public static FunctionType createClosed(int[] paramTypes, int[] resultTypes) { - assert allClosed(paramTypes) && allClosed(resultTypes); - return new FunctionType(paramTypes, resultTypes, null); - } - - private static boolean allClosed(int[] types) { - for (int type : types) { - if (WasmType.isConcreteReferenceType(type)) { - return false; - } - } - return true; - } - - public int[] paramTypes() { + public ClosedValueType[] paramTypes() { return paramTypes; } - public int[] resultTypes() { + public ClosedValueType[] resultTypes() { return resultTypes; } - public SymbolTable symbolTable() { - return symbolTable; - } - @Override - public int hashCode() { - return hashCode; - } - - @Override - public boolean equals(Object object) { - if (!(object instanceof FunctionType that)) { + public boolean matches(ClosedHeapType heapSubType) { + if (!(heapSubType instanceof ClosedFunctionType functionSubType)) { return false; } - if (this.paramTypes.length != that.paramTypes.length) { + if (this.paramTypes.length != functionSubType.paramTypes.length) { return false; } for (int i = 0; i < this.paramTypes.length; i++) { - if (this.paramTypes[i] != that.paramTypes[i]) { + if (!functionSubType.paramTypes[i].matches(this.paramTypes[i])) { return false; } } - if (this.resultTypes.length != that.resultTypes.length) { + if (this.resultTypes.length != functionSubType.resultTypes.length) { return false; } for (int i = 0; i < this.resultTypes.length; i++) { - if (this.resultTypes[i] != that.resultTypes[i]) { + if (!this.resultTypes[i].matches(functionSubType.resultTypes[i])) { return false; } } @@ -268,17 +267,8 @@ public boolean equals(Object object) { } @Override - public String toString() { - CompilerAsserts.neverPartOfCompilation(); - String[] paramNames = new String[paramTypes.length]; - for (int i = 0; i < paramTypes.length; i++) { - paramNames[i] = WasmType.toString(paramTypes[i]); - } - String[] resultNames = new String[resultTypes.length]; - for (int i = 0; i < resultTypes.length; i++) { - resultNames[i] = WasmType.toString(resultTypes[i]); - } - return "(" + String.join(" ", paramNames) + ")->(" + String.join(" ", resultNames) + ")"; + public Kind kind() { + return Kind.Function; } } @@ -775,7 +765,7 @@ public int functionTypeResultTypeAt(int typeIndex, int resultIndex) { return typeData[typeOffset + 2 + paramCount + resultIndex]; } - private int[] functionTypeParamTypesAsArray(int typeIndex) { + public int[] functionTypeParamTypesAsArray(int typeIndex) { int paramCount = functionTypeParamCount(typeIndex); int[] paramTypes = new int[paramCount]; for (int i = 0; i < paramCount; ++i) { @@ -784,7 +774,7 @@ private int[] functionTypeParamTypesAsArray(int typeIndex) { return paramTypes; } - private int[] functionTypeResultTypesAsArray(int typeIndex) { + public int[] functionTypeResultTypesAsArray(int typeIndex) { int resultTypeCount = functionTypeResultCount(typeIndex); int[] resultTypes = new int[resultTypeCount]; for (int i = 0; i < resultTypeCount; i++) { @@ -797,10 +787,6 @@ int typeCount() { return typeCount; } - public FunctionType typeAt(int index) { - return new FunctionType(functionTypeParamTypesAsArray(index), functionTypeResultTypesAsArray(index), this); - } - public ClosedFunctionType closedFunctionTypeAt(int typeIndex) { ClosedValueType[] paramTypes = new ClosedValueType[functionTypeParamCount(typeIndex)]; for (int i = 0; i < paramTypes.length; i++) { @@ -826,6 +812,7 @@ public ClosedValueType closedTypeAt(int type) { yield switch (WasmType.getAbstractHeapType(type)) { case WasmType.FUNC_HEAPTYPE -> nullable ? ClosedReferenceType.FUNCREF : ClosedReferenceType.NONNULL_FUNCREF; case WasmType.EXTERN_HEAPTYPE -> nullable ? ClosedReferenceType.EXTERNREF : ClosedReferenceType.NONNULL_EXTERNREF; + case WasmType.EXN_HEAPTYPE -> nullable ? ClosedReferenceType.EXNREF : ClosedReferenceType.NONNULL_EXNREF; default -> { assert WasmType.isConcreteReferenceType(type); yield new ClosedReferenceType(nullable, closedFunctionTypeAt(WasmType.getTypeIndex(type))); @@ -835,10 +822,6 @@ yield switch (WasmType.getAbstractHeapType(type)) { }; } - public boolean matches(int superType, int subType) { - return closedTypeAt(superType).matches(closedTypeAt(subType)); - } - public void importSymbol(ImportDescriptor descriptor) { checkNotParsed(); assert importedSymbols.size() == descriptor.importedSymbolIndex(); @@ -1313,7 +1296,7 @@ public void allocateTag(int index, byte attribute, int typeIndex) { checkNotParsed(); addTag(index, attribute, typeIndex); module().addLinkAction((context, store, instance, imports) -> { - final WasmTag tag = new WasmTag(typeAt(typeIndex)); + final WasmTag tag = new WasmTag(closedFunctionTypeAt(typeIndex)); instance.setTag(index, tag); }); } @@ -1322,7 +1305,7 @@ public void importTag(String moduleName, String tagName, int index, byte attribu checkNotParsed(); addTag(index, attribute, typeIndex); final ImportDescriptor importedTag = new ImportDescriptor(moduleName, tagName, ImportIdentifier.TAG, index, numImportedSymbols()); - final FunctionType type = typeAt(typeIndex); + final ClosedFunctionType type = closedFunctionTypeAt(typeIndex); importedTags.put(index, importedTag); importSymbol(importedTag); module().addLinkAction((context, store, instance, imports) -> { diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmFunction.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmFunction.java index 20bca810317d..d2af6a9c437b 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmFunction.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmFunction.java @@ -76,6 +76,10 @@ public int paramTypeAt(int argumentIndex) { return symbolTable.functionTypeParamTypeAt(typeIndex, argumentIndex); } + public int[] paramTypes() { + return symbolTable.functionTypeParamTypesAsArray(typeIndex); + } + public int resultCount() { return symbolTable.functionTypeResultCount(typeIndex); } @@ -84,6 +88,10 @@ public int resultTypeAt(int returnIndex) { return symbolTable.functionTypeResultTypeAt(typeIndex, returnIndex); } + public int[] resultTypes() { + return symbolTable.functionTypeResultTypesAsArray(typeIndex); + } + @Override public String toString() { return name(); @@ -136,10 +144,6 @@ public int typeIndex() { return typeIndex; } - public SymbolTable.FunctionType type() { - return symbolTable.typeAt(typeIndex()); - } - public SymbolTable.ClosedFunctionType closedType() { return symbolTable.closedFunctionTypeAt(typeIndex()); } @@ -172,7 +176,7 @@ public CallTarget getOrCreateInteropCallAdapter(WasmLanguage language) { CallTarget callAdapter = this.interopCallAdapter; if (callAdapter == null) { // Benign initialization race: The call target will be the same each time. - callAdapter = language.interopCallAdapterFor(type()); + callAdapter = language.interopCallAdapterFor(closedType()); this.interopCallAdapter = callAdapter; } return callAdapter; diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmInstantiator.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmInstantiator.java index e359b04a3846..8f64f7732f77 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmInstantiator.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmInstantiator.java @@ -201,7 +201,7 @@ static List recreateLinkActions(WasmModule module) { for (int i = 0; i < module.tagCount(); i++) { final int tagIndex = i; final int typeIndex = module.tagTypeIndex(tagIndex); - final SymbolTable.FunctionType type = module.typeAt(typeIndex); + final SymbolTable.ClosedFunctionType type = module.closedFunctionTypeAt(typeIndex); final ImportDescriptor tagDescriptor = module.importedTag(tagIndex); if (tagDescriptor != null) { linkActions.add((context, store, instance, imports) -> { diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmLanguage.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmLanguage.java index c5c6f90cd6aa..c376dab7fb3e 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmLanguage.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmLanguage.java @@ -95,13 +95,13 @@ public final class WasmLanguage extends TruffleLanguage { private final Map builtinModules = new ConcurrentHashMap<>(); - private final Map interopCallAdapters = new ConcurrentHashMap<>(); + private final Map interopCallAdapters = new ConcurrentHashMap<>(); /** * Gets or creates the interop call adapter for a function type. Always returns the same call * target for any particular type. */ - public CallTarget interopCallAdapterFor(SymbolTable.FunctionType type) { + public CallTarget interopCallAdapterFor(SymbolTable.ClosedFunctionType type) { CompilerAsserts.neverPartOfCompilation(); CallTarget callAdapter = interopCallAdapters.get(type); if (callAdapter == null) { diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmTag.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmTag.java index 1e0ecf0afd03..f760faf8a368 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmTag.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmTag.java @@ -47,13 +47,13 @@ public static final class Attribute { public static final int EXCEPTION = 0; } - private final SymbolTable.FunctionType type; + private final SymbolTable.ClosedFunctionType type; - public WasmTag(SymbolTable.FunctionType type) { + public WasmTag(SymbolTable.ClosedFunctionType type) { this.type = type; } - public SymbolTable.FunctionType type() { + public SymbolTable.ClosedFunctionType type() { return type; } } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/FuncType.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/FuncType.java index 568a234ad6b2..2cef0d2abe41 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/FuncType.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/FuncType.java @@ -84,18 +84,18 @@ private static ValueType[] parseTypeString(String typesString, int start, int en } } - public static FuncType fromFunctionType(SymbolTable.FunctionType functionType) { - final int[] paramTypes = functionType.paramTypes(); - final int[] resultTypes = functionType.resultTypes(); + public static FuncType fromClosedFunctionType(SymbolTable.ClosedFunctionType functionType) { + final SymbolTable.ClosedValueType[] paramTypes = functionType.paramTypes(); + final SymbolTable.ClosedValueType[] resultTypes = functionType.resultTypes(); final ValueType[] params = new ValueType[paramTypes.length]; final ValueType[] results = new ValueType[resultTypes.length]; for (int i = 0; i < paramTypes.length; i++) { - params[i] = ValueType.fromValue(paramTypes[i]); + params[i] = ValueType.fromClosedValueType(paramTypes[i]); } for (int i = 0; i < resultTypes.length; i++) { - results[i] = ValueType.fromValue(resultTypes[i]); + results[i] = ValueType.fromClosedValueType(resultTypes[i]); } return new FuncType(params, results); } @@ -116,17 +116,17 @@ public int resultCount() { return results.length; } - public SymbolTable.FunctionType toFunctionType() { - final int[] paramTypes = new int[params.length]; - final int[] resultTypes = new int[results.length]; + public SymbolTable.ClosedFunctionType toClosedFunctionType() { + var paramTypes = new SymbolTable.ClosedValueType[params.length]; + var resultTypes = new SymbolTable.ClosedValueType[results.length]; for (int i = 0; i < paramTypes.length; i++) { - paramTypes[i] = params[i].value(); + paramTypes[i] = params[i].asClosedValueType(); } for (int i = 0; i < resultTypes.length; i++) { - resultTypes[i] = results[i].value(); + resultTypes[i] = results[i].asClosedValueType(); } - return SymbolTable.FunctionType.createClosed(paramTypes, resultTypes); + return new SymbolTable.ClosedFunctionType(paramTypes, resultTypes); } public String toString(StringBuilder b) { diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/InteropCallAdapterNode.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/InteropCallAdapterNode.java index 15cf1c5e8f56..9df6ee0dfafc 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/InteropCallAdapterNode.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/InteropCallAdapterNode.java @@ -74,11 +74,11 @@ public final class InteropCallAdapterNode extends RootNode { private static final int MAX_UNROLL = 32; - private final SymbolTable.FunctionType functionType; + private final SymbolTable.ClosedFunctionType functionType; private final BranchProfile errorBranch = BranchProfile.create(); @Child private WasmIndirectCallNode callNode; - public InteropCallAdapterNode(WasmLanguage language, SymbolTable.FunctionType functionType) { + public InteropCallAdapterNode(WasmLanguage language, SymbolTable.ClosedFunctionType functionType) { super(language); this.functionType = functionType; this.callNode = WasmIndirectCallNode.create(); @@ -109,80 +109,90 @@ public Object execute(VirtualFrame frame) { } private Object[] validateArguments(Object[] arguments, int offset) throws ArityException, UnsupportedTypeException { - final int[] paramTypes = functionType.paramTypes(); + final SymbolTable.ClosedValueType[] paramTypes = functionType.paramTypes(); final int paramCount = paramTypes.length; CompilerAsserts.partialEvaluationConstant(paramCount); if (arguments.length - offset != paramCount) { throw ArityException.create(paramCount, paramCount, arguments.length - offset); } if (CompilerDirectives.inCompiledCode() && paramCount <= MAX_UNROLL) { - validateArgumentsUnroll(arguments, offset, paramTypes, paramCount, functionType.symbolTable()); + validateArgumentsUnroll(arguments, offset, paramTypes, paramCount); } else { for (int i = 0; i < paramCount; i++) { - validateArgument(arguments, offset, paramTypes, i, functionType.symbolTable()); + validateArgument(arguments, offset, paramTypes, i); } } return arguments; } @ExplodeLoop - private static void validateArgumentsUnroll(Object[] arguments, int offset, int[] paramTypes, int paramCount, SymbolTable symbolTable) throws UnsupportedTypeException { + private static void validateArgumentsUnroll(Object[] arguments, int offset, SymbolTable.ClosedValueType[] paramTypes, int paramCount) throws UnsupportedTypeException { for (int i = 0; i < paramCount; i++) { - validateArgument(arguments, offset, paramTypes, i, symbolTable); + validateArgument(arguments, offset, paramTypes, i); } } - private static void validateArgument(Object[] arguments, int offset, int[] paramTypes, int i, SymbolTable symbolTable) throws UnsupportedTypeException { - int paramType = paramTypes[i]; + private static void validateArgument(Object[] arguments, int offset, SymbolTable.ClosedValueType[] paramTypes, int i) throws UnsupportedTypeException { + SymbolTable.ClosedValueType paramType = paramTypes[i]; Object value = arguments[i + offset]; - switch (paramType) { - case WasmType.I32_TYPE -> { - if (value instanceof Integer) { - return; - } - } - case WasmType.I64_TYPE -> { - if (value instanceof Long) { - return; - } - } - case WasmType.F32_TYPE -> { - if (value instanceof Float) { - return; - } - } - case WasmType.F64_TYPE -> { - if (value instanceof Double) { - return; - } - } - case WasmType.V128_TYPE -> { - if (value instanceof Vector128) { - return; - } - } - default -> { - assert WasmType.isReferenceType(paramType); - boolean nullable = WasmType.isNullable(paramType); - switch (WasmType.getAbstractHeapType(paramType)) { - case WasmType.FUNCREF_TYPE -> { - if (value instanceof WasmFunctionInstance || nullable && value == WasmConstant.NULL) { + switch (paramType.kind()) { + case Number -> { + SymbolTable.NumberType numberType = (SymbolTable.NumberType) paramType; + switch (numberType.value()) { + case WasmType.I32_TYPE -> { + if (value instanceof Integer) { return; } } - case WasmType.EXTERNREF_TYPE -> { - if (nullable || value != WasmConstant.NULL) { + case WasmType.I64_TYPE -> { + if (value instanceof Long) { return; } } - case WasmType.EXNREF_TYPE -> { - if (value instanceof WasmRuntimeException || nullable && value == WasmConstant.NULL) { + case WasmType.F32_TYPE -> { + if (value instanceof Float) { return; } } - default -> { - assert WasmType.isConcreteReferenceType(paramType); - SymbolTable.ClosedFunctionType functionType = symbolTable.closedFunctionTypeAt(WasmType.getTypeIndex(paramType)); + case WasmType.F64_TYPE -> { + if (value instanceof Double) { + return; + } + } + } + } + case Vector -> { + if (value instanceof Vector128) { + return; + } + } + case Reference -> { + SymbolTable.ClosedReferenceType referenceType = (SymbolTable.ClosedReferenceType) paramType; + boolean nullable = referenceType.nullable(); + SymbolTable.ClosedHeapType heapType = referenceType.heapType(); + switch (heapType.kind()) { + case Abstract -> { + SymbolTable.AbstractHeapType abstractHeapType = (SymbolTable.AbstractHeapType) heapType; + switch (abstractHeapType.value()) { + case WasmType.FUNCREF_TYPE -> { + if (value instanceof WasmFunctionInstance || nullable && value == WasmConstant.NULL) { + return; + } + } + case WasmType.EXTERNREF_TYPE -> { + if (nullable || value != WasmConstant.NULL) { + return; + } + } + case WasmType.EXNREF_TYPE -> { + if (value instanceof WasmRuntimeException || nullable && value == WasmConstant.NULL) { + return; + } + } + } + } + case Function -> { + SymbolTable.ClosedFunctionType functionType = (SymbolTable.ClosedFunctionType) heapType; if (value instanceof WasmFunctionInstance instance && functionType.matches(instance.function().closedType()) || nullable && value == WasmConstant.NULL) { return; } @@ -197,7 +207,7 @@ private Object multiValueStackAsArray(WasmLanguage language) { final var multiValueStack = language.multiValueStack(); final long[] primitiveMultiValueStack = multiValueStack.primitiveStack(); final Object[] objectMultiValueStack = multiValueStack.objectStack(); - final int[] resultTypes = functionType.resultTypes(); + final SymbolTable.ClosedValueType[] resultTypes = functionType.resultTypes(); final int resultCount = resultTypes.length; assert primitiveMultiValueStack.length >= resultCount; assert objectMultiValueStack.length >= resultCount; @@ -214,21 +224,26 @@ private Object multiValueStackAsArray(WasmLanguage language) { } @ExplodeLoop - private static void popMultiValueResultUnroll(Object[] values, long[] primitiveMultiValueStack, Object[] objectMultiValueStack, int[] resultTypes, int resultCount) { + private static void popMultiValueResultUnroll(Object[] values, long[] primitiveMultiValueStack, Object[] objectMultiValueStack, SymbolTable.ClosedValueType[] resultTypes, int resultCount) { for (int i = 0; i < resultCount; i++) { values[i] = popMultiValueResult(primitiveMultiValueStack, objectMultiValueStack, resultTypes, i); } } - private static Object popMultiValueResult(long[] primitiveMultiValueStack, Object[] objectMultiValueStack, int[] resultTypes, int i) { - final int resultType = resultTypes[i]; - return switch (resultType) { - case WasmType.I32_TYPE -> (int) primitiveMultiValueStack[i]; - case WasmType.I64_TYPE -> primitiveMultiValueStack[i]; - case WasmType.F32_TYPE -> Float.intBitsToFloat((int) primitiveMultiValueStack[i]); - case WasmType.F64_TYPE -> Double.longBitsToDouble(primitiveMultiValueStack[i]); - default -> { - assert resultType == WasmType.V128_TYPE || WasmType.isReferenceType(resultType); + private static Object popMultiValueResult(long[] primitiveMultiValueStack, Object[] objectMultiValueStack, SymbolTable.ClosedValueType[] resultTypes, int i) { + final SymbolTable.ClosedValueType resultType = resultTypes[i]; + return switch (resultType.kind()) { + case Number -> { + SymbolTable.NumberType numberType = (SymbolTable.NumberType) resultType; + yield switch (numberType.value()) { + case WasmType.I32_TYPE -> (int) primitiveMultiValueStack[i]; + case WasmType.I64_TYPE -> primitiveMultiValueStack[i]; + case WasmType.F32_TYPE -> Float.intBitsToFloat((int) primitiveMultiValueStack[i]); + case WasmType.F64_TYPE -> Double.longBitsToDouble(primitiveMultiValueStack[i]); + default -> throw CompilerDirectives.shouldNotReachHere(); + }; + } + case Vector, Reference -> { Object obj = objectMultiValueStack[i]; objectMultiValueStack[i] = null; yield obj; diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ValueType.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ValueType.java index 076b1d5942e3..b5bd2f2be3e4 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ValueType.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ValueType.java @@ -40,6 +40,7 @@ */ package org.graalvm.wasm.api; +import org.graalvm.wasm.SymbolTable; import org.graalvm.wasm.WasmType; import org.graalvm.wasm.exception.Failure; import org.graalvm.wasm.exception.WasmException; @@ -53,8 +54,7 @@ public enum ValueType { f64(WasmType.F64_TYPE), v128(WasmType.V128_TYPE), anyfunc(WasmType.FUNCREF_TYPE), - externref(WasmType.EXTERNREF_TYPE), - exnref(WasmType.EXNREF_TYPE); + externref(WasmType.EXTERNREF_TYPE); private final int value; @@ -72,7 +72,6 @@ public static ValueType fromValue(int value) { case WasmType.V128_TYPE -> v128; case WasmType.FUNCREF_TYPE -> anyfunc; case WasmType.EXTERNREF_TYPE -> externref; - case WasmType.EXNREF_TYPE -> exnref; default -> throw WasmException.create(Failure.UNSPECIFIED_INTERNAL, null, "Unknown value type: 0x" + Integer.toHexString(value)); }; } @@ -90,6 +89,43 @@ public static boolean isVectorType(ValueType valueType) { } public static boolean isReferenceType(ValueType valueType) { - return valueType == anyfunc || valueType == externref || valueType == exnref; + return valueType == anyfunc || valueType == externref; + } + + public SymbolTable.ClosedValueType asClosedValueType() { + return switch (value) { + case WasmType.I32_TYPE -> SymbolTable.NumberType.I32; + case WasmType.I64_TYPE -> SymbolTable.NumberType.I64; + case WasmType.F32_TYPE -> SymbolTable.NumberType.F32; + case WasmType.F64_TYPE -> SymbolTable.NumberType.F64; + case WasmType.V128_TYPE -> SymbolTable.VectorType.V128; + case WasmType.FUNCREF_TYPE -> SymbolTable.ClosedReferenceType.FUNCREF; + case WasmType.EXTERNREF_TYPE -> SymbolTable.ClosedReferenceType.EXTERNREF; + default -> throw WasmException.create(Failure.UNSPECIFIED_INTERNAL, null, "Unknown value type: 0x" + Integer.toHexString(value)); + }; + } + + public static ValueType fromClosedValueType(SymbolTable.ClosedValueType closedValueType) { + return switch (closedValueType.kind()) { + case Number -> fromValue(((SymbolTable.NumberType) closedValueType).value()); + case Vector -> fromValue(((SymbolTable.VectorType) closedValueType).value()); + case Reference -> { + SymbolTable.ClosedReferenceType referenceType = (SymbolTable.ClosedReferenceType) closedValueType; + if (!referenceType.nullable()) { + throw WasmException.create(Failure.UNSPECIFIED_INTERNAL, null, "Unknown value type: non-nullable reference"); + } + yield switch (referenceType.heapType().kind()) { + case Abstract -> { + SymbolTable.AbstractHeapType abstractHeapType = (SymbolTable.AbstractHeapType) referenceType.heapType(); + yield switch (abstractHeapType.value()) { + case WasmType.FUNC_HEAPTYPE -> anyfunc; + case WasmType.EXTERNREF_TYPE -> externref; + default -> throw WasmException.create(Failure.UNSPECIFIED_INTERNAL, null, "Unknown value type: 0x" + Integer.toHexString(abstractHeapType.value())); + }; + } + case Function -> throw WasmException.create(Failure.UNSPECIFIED_INTERNAL, null, "Unknown value type: typed function reference"); + }; + } + }; } } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/WebAssembly.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/WebAssembly.java index 24221ff68fdd..c9b3f555705e 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/WebAssembly.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/WebAssembly.java @@ -591,7 +591,7 @@ private static String tagTypeToString(WasmModule module, int tagIndex) { assert attribute == WasmTag.Attribute.EXCEPTION; final int typeIndex = module.tagTypeIndex(tagIndex); - return FuncType.fromFunctionType(module.typeAt(typeIndex)).toString(); + return FuncType.fromClosedFunctionType(module.closedFunctionTypeAt(typeIndex)).toString(); } private static Object memAlloc(Object[] args) { @@ -942,7 +942,7 @@ public static Object tagAlloc(Object[] args) { } public static WasmTag tagAlloc(FuncType type) { - return new WasmTag(type.toFunctionType()); + return new WasmTag(type.toClosedFunctionType()); } public static Object tagType(Object[] args) { @@ -958,7 +958,7 @@ public Object exnAlloc(Object[] args) { if (!(args[0] instanceof WasmTag tag)) { throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "First argument must be a wasm tag"); } - final FuncType type = FuncType.fromFunctionType(tag.type()); + final FuncType type = FuncType.fromClosedFunctionType(tag.type()); final int paramCount = type.paramCount(); checkArgumentCount(args, paramCount + 1); final Object[] fields = new Object[paramCount]; diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/globals/WasmGlobal.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/globals/WasmGlobal.java index c6afa14d4d95..8f5b1a1a3edc 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/globals/WasmGlobal.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/globals/WasmGlobal.java @@ -48,7 +48,6 @@ import org.graalvm.wasm.api.ValueType; import org.graalvm.wasm.api.Vector128; import org.graalvm.wasm.constants.GlobalModifier; -import org.graalvm.wasm.exception.WasmRuntimeException; import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; import com.oracle.truffle.api.interop.InteropLibrary; @@ -179,7 +178,7 @@ Object readMember(String member) throws UnknownIdentifierException { case f32 -> Float.intBitsToFloat(loadAsInt()); case f64 -> Double.longBitsToDouble(loadAsLong()); case v128 -> loadAsVector128(); - case anyfunc, externref, exnref -> loadAsReference(); + case anyfunc, externref -> loadAsReference(); }; } @@ -229,12 +228,6 @@ void writeMember(String member, Object value, } throw UnsupportedMessageException.create(); } - case exnref -> { - if (value == WasmConstant.NULL || value instanceof WasmRuntimeException) { - storeReference(value); - } - throw UnsupportedMessageException.create(); - } } } From 9e563adcbc32826832af6b45dce01fdb630fa8e9 Mon Sep 17 00:00:00 2001 From: Jirka Marsik Date: Wed, 24 Sep 2025 09:53:23 +0200 Subject: [PATCH 03/40] Use subtype matching during static type checking --- .../src/org/graalvm/wasm/BinaryParser.java | 4 ++-- .../src/org/graalvm/wasm/SymbolTable.java | 4 ++++ .../src/org/graalvm/wasm/WasmType.java | 1 + .../wasm/parser/validation/ParserState.java | 20 ++++++++++--------- .../parser/validation/ValidationErrors.java | 4 ++-- 5 files changed, 20 insertions(+), 13 deletions(-) diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java index 337523bc7d97..4def2e108d06 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java @@ -602,7 +602,7 @@ private static int[] encapsulateResultType(int type) { private CodeEntry readFunction(int functionIndex, int[] locals, int[] resultTypes, int sourceCodeEndOffset, boolean hasNextFunction, RuntimeBytecodeGen bytecode, int codeEntryIndex, EconomicMap offsetToLineIndexMap) { - final ParserState state = new ParserState(bytecode); + final ParserState state = new ParserState(bytecode, module); final ArrayList callNodes = new ArrayList<>(); final int bytecodeStartOffset = bytecode.location(); state.enterFunction(resultTypes); @@ -2597,7 +2597,7 @@ private Pair readConstantExpression(int resultType, boolean only // Read the constant expression. // https://webassembly.github.io/spec/core/valid/instructions.html#constant-expressions final RuntimeBytecodeGen bytecode = new RuntimeBytecodeGen(); - final ParserState state = new ParserState(bytecode); + final ParserState state = new ParserState(bytecode, module); final List stack = new ArrayList<>(); boolean calculable = true; diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java index 8042bcf1f075..9a0e948db393 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java @@ -822,6 +822,10 @@ yield switch (WasmType.getAbstractHeapType(type)) { }; } + public boolean matches(int expectedType, int actualType) { + return closedTypeAt(expectedType).matches(closedTypeAt(actualType)); + } + public void importSymbol(ImportDescriptor descriptor) { checkNotParsed(); assert importedSymbols.size() == descriptor.importedSymbolIndex(); diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmType.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmType.java index 9066d1e7d358..904507d23473 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmType.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmType.java @@ -98,6 +98,7 @@ public class WasmType implements TruffleObject { /** * Implementation-specific Types. */ + public static final int TOP = -0x7d; public static final int NULL_TYPE = -0x7e; public static final int UNKNOWN_TYPE = -0x7f; diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ParserState.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ParserState.java index b29c9f142709..59dd04da05a5 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ParserState.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ParserState.java @@ -46,6 +46,7 @@ import java.util.ArrayList; import org.graalvm.wasm.Assert; +import org.graalvm.wasm.SymbolTable; import org.graalvm.wasm.WasmType; import org.graalvm.wasm.api.Vector128; import org.graalvm.wasm.collection.IntArrayList; @@ -60,21 +61,22 @@ */ public class ParserState { private static final int[] EMPTY_ARRAY = new int[0]; - private static final int ANY = 0; private final IntArrayList valueStack; private final ControlStack controlStack; private final RuntimeBytecodeGen bytecode; private final ArrayList exceptionTables; + private final SymbolTable symbolTable; private int maxStackSize; private boolean usesMemoryZero; - public ParserState(RuntimeBytecodeGen bytecode) { + public ParserState(RuntimeBytecodeGen bytecode, SymbolTable symbolTable) { this.valueStack = new IntArrayList(); this.controlStack = new ControlStack(); this.bytecode = bytecode; this.exceptionTables = new ArrayList<>(); + this.symbolTable = symbolTable; this.maxStackSize = 0; } @@ -90,8 +92,8 @@ private int popInternal(int expectedValueType) { if (isCurrentStackUnreachable()) { return WasmType.UNKNOWN_TYPE; } else { - if (expectedValueType == ANY) { - throw ValidationErrors.createExpectedAnyOnEmptyStack(); + if (expectedValueType == WasmType.TOP) { + throw ValidationErrors.createExpectedTopOnEmptyStack(); } else { throw ValidationErrors.createExpectedTypeOnEmptyStack(expectedValueType); } @@ -101,7 +103,7 @@ private int popInternal(int expectedValueType) { } /** - * Pops the maximum available values form the current stack frame. If the number of values on + * Pops the maximum available values from the current stack frame. If the number of values on * the stack is smaller than the number of expectedValueTypes, only the remaining stack values * are returned. If the number of values on the stack is greater or equal to the number of * expectedValueTypes, the values equal to the number of expectedValueTypes is popped from the @@ -130,7 +132,7 @@ private int[] popAvailableUnchecked() { int[] popped = new int[availableStackSize]; int j = 0; for (int i = availableStackSize - 1; i >= 0; i--) { - popped[j] = popInternal(ANY); + popped[j] = popInternal(WasmType.TOP); j++; } return popped; @@ -151,7 +153,7 @@ private boolean isTypeMismatch(int[] expectedTypes, int[] actualTypes) { return false; } for (int i = 0; i < expectedTypes.length; i++) { - if (expectedTypes[i] != actualTypes[i]) { + if (!symbolTable.matches(expectedTypes[i], actualTypes[i])) { return true; } } @@ -186,7 +188,7 @@ public void pushAll(int[] valueTypes) { * @throws WasmException If the stack is empty. */ public int pop() { - return popInternal(ANY); + return popInternal(WasmType.TOP); } /** @@ -198,7 +200,7 @@ public int pop() { */ public int popChecked(int expectedValueType) { final int actualValueType = popInternal(expectedValueType); - if (actualValueType != expectedValueType && actualValueType != WasmType.UNKNOWN_TYPE && expectedValueType != WasmType.UNKNOWN_TYPE) { + if (!symbolTable.matches(expectedValueType, actualValueType) && actualValueType != WasmType.UNKNOWN_TYPE && expectedValueType != WasmType.UNKNOWN_TYPE) { throw ValidationErrors.createTypeMismatch(expectedValueType, actualValueType); } return actualValueType; diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ValidationErrors.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ValidationErrors.java index ad62fd3d73c9..9226b0c7c750 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ValidationErrors.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ValidationErrors.java @@ -101,8 +101,8 @@ public static WasmException createMissingFunctionType(int expected, int max) { } @TruffleBoundary - public static WasmException createExpectedAnyOnEmptyStack() { - return WasmException.create(Failure.TYPE_MISMATCH, "Expected type [any], but got []."); + public static WasmException createExpectedTopOnEmptyStack() { + return WasmException.create(Failure.TYPE_MISMATCH, "Expected type [top], but got []."); } @TruffleBoundary From 0f4dfb69a009c2ab747259a0c147494a1b2f6cba Mon Sep 17 00:00:00 2001 From: Jirka Marsik Date: Tue, 23 Sep 2025 10:18:39 +0200 Subject: [PATCH 04/40] Add typed function references instructions --- .../src/org/graalvm/wasm/BinaryParser.java | 163 ++++++++++++------ .../src/org/graalvm/wasm/SymbolTable.java | 32 +++- .../src/org/graalvm/wasm/WasmType.java | 14 +- .../wasm/api/InteropCallAdapterNode.java | 2 + .../src/org/graalvm/wasm/api/ValueType.java | 1 + .../org/graalvm/wasm/constants/Bytecode.java | 9 + .../graalvm/wasm/constants/Instructions.java | 5 + .../org/graalvm/wasm/exception/Failure.java | 2 + .../graalvm/wasm/nodes/WasmFunctionNode.java | 107 +++++++++++- .../wasm/parser/bytecode/BytecodeParser.java | 20 +++ .../parser/bytecode/RuntimeBytecodeGen.java | 89 +++++++++- .../wasm/parser/validation/BlockFrame.java | 10 ++ .../wasm/parser/validation/ControlFrame.java | 4 + .../wasm/parser/validation/IfFrame.java | 10 ++ .../wasm/parser/validation/LoopFrame.java | 10 ++ .../wasm/parser/validation/ParserState.java | 76 +++++--- 16 files changed, 464 insertions(+), 90 deletions(-) diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java index 4def2e108d06..b3c078008b6d 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java @@ -40,6 +40,7 @@ */ package org.graalvm.wasm; +import static java.lang.Integer.compareUnsigned; import static org.graalvm.wasm.Assert.assertByteEqual; import static org.graalvm.wasm.Assert.assertIntEqual; import static org.graalvm.wasm.Assert.assertIntLessOrEqual; @@ -103,6 +104,7 @@ import org.graalvm.wasm.parser.ir.CodeEntry; import org.graalvm.wasm.parser.validation.ExceptionHandler; import org.graalvm.wasm.parser.validation.ParserState; +import org.graalvm.wasm.parser.validation.ValidationErrors; import com.oracle.truffle.api.CompilerDirectives; import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; @@ -477,7 +479,7 @@ private void readFunctionSection() { module.limits().checkFunctionCount(functionCount); for (int functionIndex = 0; functionIndex != functionCount; functionIndex++) { assertTrue(!isEOF(), Failure.LENGTH_OUT_OF_BOUNDS); - int functionTypeIndex = readUnsignedInt32(); + int functionTypeIndex = readTypeIndex(); module.symbolTable().declareFunction(functionTypeIndex); } } @@ -642,7 +644,7 @@ private CodeEntry readFunction(int functionIndex, int[] locals, int[] resultType } case BLOCK_TYPE_TYPE_INDEX -> { int typeIndex = multiResult[0]; - state.checkFunctionTypeExists(typeIndex, module.typeCount()); + checkFunctionTypeExists(typeIndex); blockParamTypes = extractBlockParamTypes(typeIndex); blockResultTypes = extractBlockResultTypes(typeIndex); } @@ -669,7 +671,7 @@ private CodeEntry readFunction(int functionIndex, int[] locals, int[] resultType } case BLOCK_TYPE_TYPE_INDEX -> { int typeIndex = multiResult[0]; - state.checkFunctionTypeExists(typeIndex, module.typeCount()); + checkFunctionTypeExists(typeIndex); loopParamTypes = extractBlockParamTypes(typeIndex); loopResultTypes = extractBlockResultTypes(typeIndex); } @@ -696,7 +698,7 @@ private CodeEntry readFunction(int functionIndex, int[] locals, int[] resultType } case BLOCK_TYPE_TYPE_INDEX -> { int typeIndex = multiResult[0]; - state.checkFunctionTypeExists(typeIndex, module.typeCount()); + checkFunctionTypeExists(typeIndex); ifParamTypes = extractBlockParamTypes(typeIndex); ifResultTypes = extractBlockResultTypes(typeIndex); } @@ -789,12 +791,11 @@ private CodeEntry readFunction(int functionIndex, int[] locals, int[] resultType break; } case Instructions.CALL_INDIRECT: { - final int expectedFunctionTypeIndex = readUnsignedInt32(); + final int expectedFunctionTypeIndex = readTypeIndex(); final int tableIndex = readTableIndex(); // Pop the function index to call state.popChecked(I32_TYPE); - state.checkFunctionTypeExists(expectedFunctionTypeIndex, module.typeCount()); - assertIntEqual(FUNCREF_TYPE, module.tableElementType(tableIndex), Failure.TYPE_MISMATCH); + Assert.assertTrue(module.matches(FUNCREF_TYPE, module.tableElementType(tableIndex)), Failure.TYPE_MISMATCH); // Pop parameters for (int i = module.functionTypeParamCount(expectedFunctionTypeIndex) - 1; i >= 0; --i) { @@ -827,8 +828,8 @@ private CodeEntry readFunction(int functionIndex, int[] locals, int[] resultType final int t1 = state.pop(); // first operand final int t2 = state.pop(); // second operand assertTrue((WasmType.isNumberType(t1) || WasmType.isVectorType(t1)) && (WasmType.isNumberType(t2) || WasmType.isVectorType(t2)), Failure.TYPE_MISMATCH); - assertTrue(t1 == t2 || t1 == WasmType.UNKNOWN_TYPE || t2 == WasmType.UNKNOWN_TYPE, Failure.TYPE_MISMATCH); - final int t = t1 == WasmType.UNKNOWN_TYPE ? t2 : t1; + assertTrue(t1 == t2 || WasmType.isBottomType(t1) || WasmType.isBottomType(t2), Failure.TYPE_MISMATCH); + final int t = WasmType.isBottomType(t1) ? t2 : t1; state.push(t); if (WasmType.isNumberType(t)) { state.addSelectInstruction(Bytecode.SELECT); @@ -870,7 +871,7 @@ private CodeEntry readFunction(int functionIndex, int[] locals, int[] resultType } case BLOCK_TYPE_TYPE_INDEX -> { int typeIndex = multiResult[0]; - state.checkFunctionTypeExists(typeIndex, module.typeCount()); + checkFunctionTypeExists(typeIndex); tryTableParamTypes = extractBlockParamTypes(typeIndex); tryTableResultTypes = extractBlockResultTypes(typeIndex); } @@ -1104,6 +1105,50 @@ private CodeEntry readFunction(int functionIndex, int[] locals, int[] resultType } break; } + case Instructions.CALL_REF: { + final int expectedFunctionTypeIndex = readTypeIndex(); + final int functionReferenceType = WasmType.withNullable(true, expectedFunctionTypeIndex); + state.popChecked(functionReferenceType); + // Pop parameters + final int paramCount = module.functionTypeParamCount(expectedFunctionTypeIndex); + for (int i = paramCount - 1; i >= 0; i--) { + state.popChecked(module.functionTypeParamTypeAt(expectedFunctionTypeIndex, i)); + } + // Push result values + final int resultCount = module.functionTypeResultCount(expectedFunctionTypeIndex); + if (!multiValue) { + assertIntLessOrEqual(resultCount, 1, Failure.INVALID_RESULT_ARITY); + } + for (int i = 0; i < resultCount; i++) { + state.push(module.functionTypeResultTypeAt(expectedFunctionTypeIndex, i)); + } + state.addRefCall(callNodes.size(), expectedFunctionTypeIndex); + callNodes.add(new CallNode(bytecode.location())); + break; + } + case Instructions.REF_AS_NON_NULL: { + final int referenceType = state.popReferenceTypeChecked(); + final int nonNullReferenceType = WasmType.withNullable(false, referenceType); + state.push(nonNullReferenceType); + state.addMiscFlag(); + state.addInstruction(Bytecode.REF_AS_NON_NULL); + break; + } + case Instructions.BR_ON_NULL: { + final int branchLabel = readTargetOffset(); + final int referenceType = state.popReferenceTypeChecked(); + state.addBranchOnNull(branchLabel); + final int nonNullReferenceType = WasmType.withNullable(false, referenceType); + state.push(nonNullReferenceType); + break; + } + case Instructions.BR_ON_NON_NULL: { + final int branchLabel = readTargetOffset(); + final int referenceType = state.popReferenceTypeChecked(); + final int nonNullReferenceType = WasmType.withNullable(false, referenceType); + state.addBranchOnNonNull(branchLabel, nonNullReferenceType); + break; + } default: readNumericInstructions(state, opcode); break; @@ -1581,7 +1626,7 @@ private void readNumericInstructions(ParserState state, int opcode) { final int destinationElementType = module.tableElementType(destinationTableIndex); final int sourceTableIndex = readTableIndex(); final int sourceElementType = module.tableElementType(sourceTableIndex); - assertIntEqual(sourceElementType, destinationElementType, Failure.TYPE_MISMATCH); + Assert.assertTrue(module.matches(destinationElementType, sourceElementType), Failure.TYPE_MISMATCH); state.popChecked(I32_TYPE); state.popChecked(I32_TYPE); state.popChecked(I32_TYPE); @@ -1639,8 +1684,9 @@ private void readNumericInstructions(ParserState state, int opcode) { break; case Instructions.REF_NULL: checkBulkMemoryAndRefTypesSupport(opcode); - final int type = readRefType(exceptions); - state.push(type); + final int heapType = readHeapType(exceptions); + final int nullableReferenceType = WasmType.withNullable(true, heapType); + state.push(nullableReferenceType); state.addInstruction(Bytecode.REF_NULL); break; case Instructions.REF_IS_NULL: @@ -1653,7 +1699,8 @@ private void readNumericInstructions(ParserState state, int opcode) { checkBulkMemoryAndRefTypesSupport(opcode); final int functionIndex = readDeclaredFunctionIndex(); module.checkFunctionReference(functionIndex); - state.push(FUNCREF_TYPE); + final int functionReferenceType = WasmType.withNullable(false, module.function(functionIndex).typeIndex()); + state.push(functionReferenceType); state.addInstruction(Bytecode.REF_FUNC, functionIndex); break; case Instructions.ATOMIC: @@ -2646,8 +2693,9 @@ private Pair readConstantExpression(int resultType, boolean only } case Instructions.REF_NULL: checkBulkMemoryAndRefTypesSupport(opcode); - final int type = readRefType(exceptions); - state.push(type); + final int heapType = readHeapType(exceptions); + final int nullableReferenceType = WasmType.withNullable(true, heapType); + state.push(nullableReferenceType); state.addInstruction(Bytecode.REF_NULL); if (calculable) { stack.add(WasmConstant.NULL); @@ -2657,7 +2705,8 @@ private Pair readConstantExpression(int resultType, boolean only checkBulkMemoryAndRefTypesSupport(opcode); final int functionIndex = readDeclaredFunctionIndex(); module.addFunctionReference(functionIndex); - state.push(FUNCREF_TYPE); + final int functionReferenceType = WasmType.withNullable(false, module.function(functionIndex).typeIndex()); + state.push(functionReferenceType); state.addInstruction(Bytecode.REF_FUNC, functionIndex); calculable = false; break; @@ -2741,7 +2790,7 @@ private Pair readConstantExpression(int resultType, boolean only break; } } - assertIntEqual(state.valueStackSize(), 1, Failure.TYPE_MISMATCH, "Multiple results on stack at constant expression end"); + assertIntEqual(state.valueStackSize(), 1, Failure.TYPE_MISMATCH, "Unexpected number of results on stack at constant expression end"); state.exit(multiValue); if (calculable) { return Pair.create(stack.removeLast(), null); @@ -2793,25 +2842,23 @@ private long[] readElemExpressions(int elemType) { throw WasmException.format(Failure.ILLEGAL_OPCODE, "Illegal opcode for constant expression: 0x%02X", opcode); } case Instructions.REF_NULL: - final int type = readRefType(exceptions); - if (bulkMemoryAndRefTypes && type != elemType) { - fail(Failure.TYPE_MISMATCH, "Invalid ref.null type: 0x%02X", type); - } + final int heapType = readHeapType(exceptions); + final int nullableReferenceType = WasmType.withNullable(true, heapType); + Assert.assertTrue(module.matches(elemType, nullableReferenceType), "Invalid ref.null type: 0x%02X", Failure.TYPE_MISMATCH); functionIndices[index] = ((long) NULL_TYPE << 32); break; case Instructions.REF_FUNC: - if (elemType != FUNCREF_TYPE) { - fail(Failure.TYPE_MISMATCH, "Invalid element type: 0x%02X", FUNCREF_TYPE); - } final int functionIndex = readDeclaredFunctionIndex(); module.addFunctionReference(functionIndex); + final int functionReferenceType = WasmType.withNullable(false, module.function(functionIndex).typeIndex()); + Assert.assertTrue(module.matches(elemType, functionReferenceType), "Invalid element type: 0x%02X", Failure.TYPE_MISMATCH); functionIndices[index] = ((long) FUNCREF_TYPE << 32) | functionIndex; break; case Instructions.GLOBAL_GET: final int globalIndex = readGlobalIndex(); assertIntEqual(module.globalMutability(globalIndex), GlobalModifier.CONSTANT, Failure.CONSTANT_EXPRESSION_REQUIRED); final int valueType = module.globalValueType(globalIndex); - assertIntEqual(valueType, elemType, Failure.TYPE_MISMATCH); + Assert.assertTrue(module.matches(elemType, valueType), Failure.TYPE_MISMATCH); functionIndices[index] = ((long) I32_TYPE << 32) | globalIndex; break; case Instructions.VECTOR: @@ -3267,10 +3314,22 @@ private int readDeclaredFunctionIndex() { return index; } + /** + * Checks if the given function type is within range. + * + * @param typeIndex The function type. + * @throws WasmException If the given function type is greater or equal to the given maximum. + */ + public void checkFunctionTypeExists(int typeIndex) { + if (compareUnsigned(typeIndex, module.typeCount()) >= 0) { + throw ValidationErrors.createMissingFunctionType(typeIndex, module.tableCount() - 1); + } + } + private int readTypeIndex() { - final int result = readUnsignedInt32(); - assertUnsignedIntLess(result, module.symbolTable().typeCount(), Failure.UNKNOWN_TYPE); - return result; + final int typeIndex = readUnsignedInt32(); + checkFunctionTypeExists(typeIndex); + return typeIndex; } private int readFunctionIndex() { @@ -3322,33 +3381,33 @@ private byte readImportType() { private int readRefType(boolean allowExnType) { final int refType = readSignedInt32(); - switch (refType) { - case FUNCREF_TYPE, EXTERNREF_TYPE -> { - return refType; - } + return switch (refType) { + case FUNCREF_TYPE, EXTERNREF_TYPE -> refType; case EXNREF_TYPE -> { assertTrue(allowExnType, Failure.MALFORMED_REFERENCE_TYPE); - return refType; - } - case REF_NULL_TYPE_HEADER, REF_TYPE_HEADER -> { - boolean nullable = refType == REF_NULL_TYPE_HEADER; - int heapType = readSignedInt32(); - return switch (heapType) { - case FUNC_HEAPTYPE, EXTERN_HEAPTYPE -> WasmType.withNullable(nullable, heapType); - case EXN_HEAPTYPE -> { - assertTrue(allowExnType, Failure.MALFORMED_REFERENCE_TYPE); - yield WasmType.withNullable(nullable, heapType); - } - default -> { - if (heapType < 0 || heapType > WasmType.MAX_TYPE_INDEX) { - throw fail(Failure.MALFORMED_REFERENCE_TYPE, "Unexpected reference type"); - } - yield WasmType.withNullable(nullable, heapType); - } - }; + yield refType; } + case REF_NULL_TYPE_HEADER -> WasmType.withNullable(true, readHeapType(allowExnType)); + case REF_TYPE_HEADER -> WasmType.withNullable(false, readHeapType(allowExnType)); default -> throw fail(Failure.MALFORMED_REFERENCE_TYPE, "Unexpected reference type"); - } + }; + } + + private int readHeapType(boolean allowExnType) { + int heapType = readSignedInt32(); + return switch (heapType) { + case FUNC_HEAPTYPE, EXTERN_HEAPTYPE -> heapType; + case EXN_HEAPTYPE -> { + assertTrue(allowExnType, Failure.MALFORMED_HEAP_TYPE); + yield heapType; + } + default -> { + if (heapType < 0 || heapType > WasmType.MAX_TYPE_INDEX) { + throw fail(Failure.MALFORMED_HEAP_TYPE, "Unexpected heap type"); + } + yield heapType; + } + }; } private void readTableLimits(int[] out) { diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java index 9a0e948db393..d1cbcecce507 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java @@ -88,7 +88,8 @@ public abstract static sealed class ClosedValueType { public enum Kind { Number, Vector, - Reference + Reference, + Bottom } public abstract boolean matches(ClosedValueType valueSubType); @@ -126,7 +127,7 @@ public int value() { @Override public boolean matches(ClosedValueType valueSubType) { - return this == valueSubType; + return valueSubType == BottomType.BOTTOM || valueSubType == this; } @Override @@ -150,7 +151,7 @@ public int value() { @Override public boolean matches(ClosedValueType valueSubType) { - return this == valueSubType; + return valueSubType == BottomType.BOTTOM || valueSubType == this; } @Override @@ -185,7 +186,7 @@ public ClosedHeapType heapType() { @Override public boolean matches(ClosedValueType valueSubType) { - return valueSubType instanceof ClosedReferenceType referenceSubType && (!referenceSubType.nullable || this.nullable) && this.closedHeapType.matches(referenceSubType.closedHeapType); + return valueSubType == BottomType.BOTTOM || valueSubType instanceof ClosedReferenceType referenceSubType && (!referenceSubType.nullable || this.nullable) && this.closedHeapType.matches(referenceSubType.closedHeapType); } @Override @@ -272,6 +273,22 @@ public Kind kind() { } } + public static final class BottomType extends ClosedValueType { + public static final BottomType BOTTOM = new BottomType(); + + private BottomType() {} + + @Override + public boolean matches(ClosedValueType valueSubType) { + return valueSubType instanceof BottomType; + } + + @Override + public Kind kind() { + return Kind.Bottom; + } + } + /** * @param initialSize Lower bound on table size. * @param maximumSize Upper bound on table size. @@ -529,7 +546,7 @@ public record TagInfo(byte attribute, int typeIndex) { @CompilationFinal private int codeEntryCount; /** - * All function indices that can be references via + * All function indices that can be referenced via * {@link org.graalvm.wasm.constants.Instructions#REF_FUNC}. */ @CompilationFinal private EconomicSet functionReferences; @@ -807,6 +824,9 @@ public ClosedValueType closedTypeAt(int type) { case WasmType.F64_TYPE -> NumberType.F64; case WasmType.V128_TYPE -> VectorType.V128; default -> { + if (WasmType.isBottomType(type)) { + yield BottomType.BOTTOM; + } assert WasmType.isReferenceType(type); boolean nullable = WasmType.isNullable(type); yield switch (WasmType.getAbstractHeapType(type)) { @@ -1451,7 +1471,7 @@ public void checkElemIndex(int elemIndex) { } public void checkElemType(int elemIndex, int expectedType) { - assertIntEqual(expectedType, (int) elemInstances[elemIndex], Failure.TYPE_MISMATCH); + Assert.assertTrue(matches(expectedType, (int) elemInstances[elemIndex]), Failure.TYPE_MISMATCH); } private void ensureElemInstanceCapacity(int index) { diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmType.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmType.java index 904507d23473..4bd32e073aae 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmType.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmType.java @@ -100,7 +100,7 @@ public class WasmType implements TruffleObject { */ public static final int TOP = -0x7d; public static final int NULL_TYPE = -0x7e; - public static final int UNKNOWN_TYPE = -0x7f; + public static final int BOT = -0x7f; /** * Bytes used in the binary encoding of types. @@ -155,19 +155,23 @@ yield switch (WasmType.getAbstractHeapType(valueType)) { } public static boolean isNumberType(int type) { - return type == I32_TYPE || type == I64_TYPE || type == F32_TYPE || type == F64_TYPE || type == UNKNOWN_TYPE; + return type == I32_TYPE || type == I64_TYPE || type == F32_TYPE || type == F64_TYPE || isBottomType(type); } public static boolean isVectorType(int type) { - return type == V128_TYPE || type == UNKNOWN_TYPE; + return type == V128_TYPE || isBottomType(type); } public static boolean isReferenceType(int type) { - return isConcreteReferenceType(type) || withNullable(true, type) == FUNC_HEAPTYPE || withNullable(true, type) == EXTERN_HEAPTYPE || withNullable(true, type) == EXN_HEAPTYPE || type == UNKNOWN_TYPE; + return isConcreteReferenceType(type) || withNullable(true, type) == FUNC_HEAPTYPE || withNullable(true, type) == EXTERN_HEAPTYPE || withNullable(true, type) == EXN_HEAPTYPE || isBottomType(type); + } + + public static boolean isBottomType(int type) { + return withNullable(true, type) == BOT; } public static boolean isConcreteReferenceType(int type) { - return type >= 0; + return type >= 0 || isBottomType(type); } public static int getTypeIndex(int type) { diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/InteropCallAdapterNode.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/InteropCallAdapterNode.java index 9df6ee0dfafc..04dbb68168f4 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/InteropCallAdapterNode.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/InteropCallAdapterNode.java @@ -199,6 +199,7 @@ private static void validateArgument(Object[] arguments, int offset, SymbolTable } } } + case Bottom -> throw CompilerDirectives.shouldNotReachHere(); } throw UnsupportedTypeException.create(arguments); } @@ -248,6 +249,7 @@ yield switch (numberType.value()) { objectMultiValueStack[i] = null; yield obj; } + case Bottom -> throw CompilerDirectives.shouldNotReachHere(); }; } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ValueType.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ValueType.java index b5bd2f2be3e4..5a27fc9a6afb 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ValueType.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ValueType.java @@ -126,6 +126,7 @@ yield switch (abstractHeapType.value()) { case Function -> throw WasmException.create(Failure.UNSPECIFIED_INTERNAL, null, "Unknown value type: typed function reference"); }; } + case Bottom -> throw WasmException.create(Failure.UNSPECIFIED_INTERNAL, null, "Unknown value type: bottom"); }; } } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/Bytecode.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/Bytecode.java index 50d177c05564..5095fa779a27 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/Bytecode.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/Bytecode.java @@ -376,6 +376,15 @@ public class Bytecode { public static final int THROW = 0x1C; public static final int THROW_REF = 0x1D; + // Typed function references opcodes + public static final int CALL_REF_U8 = 0x1E; + public static final int CALL_REF_I32 = 0x1F; + public static final int REF_AS_NON_NULL = 0x20; + public static final int BR_ON_NULL_U8 = 0x21; + public static final int BR_ON_NULL_I32 = 0x22; + public static final int BR_ON_NON_NULL_U8 = 0x23; + public static final int BR_ON_NON_NULL_I32 = 0x24; + // Atomic opcodes public static final int ATOMIC_I32_LOAD = 0x00; public static final int ATOMIC_I64_LOAD = 0x01; diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/Instructions.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/Instructions.java index e0a7b68bad5b..2bf625b5fd61 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/Instructions.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/Instructions.java @@ -277,6 +277,11 @@ public final class Instructions { public static final int TABLE_SIZE = 16; public static final int TABLE_FILL = 17; + public static final int CALL_REF = 0x14; + public static final int REF_AS_NON_NULL = 0xD4; + public static final int BR_ON_NULL = 0xD5; + public static final int BR_ON_NON_NULL = 0xD6; + public static final int ATOMIC = 0xFE; public static final int ATOMIC_NOTIFY = 0x00; diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/Failure.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/Failure.java index 11d25cc18f8f..be435258ea12 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/Failure.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/Failure.java @@ -62,6 +62,7 @@ public enum Failure { DATA_COUNT_SECTION_REQUIRED(Type.MALFORMED, "data count section required"), ILLEGAL_OPCODE(Type.MALFORMED, "illegal opcode"), MALFORMED_REFERENCE_TYPE(Type.MALFORMED, "malformed reference type"), + MALFORMED_HEAP_TYPE(Type.MALFORMED, "malformed heap type"), MALFORMED_IMPORT_KIND(Type.MALFORMED, "malformed import kind"), END_OPCODE_EXPECTED(Type.MALFORMED, "END opcode expected"), UNEXPECTED_CONTENT_AFTER_LAST_SECTION(Type.MALFORMED, "unexpected content after last section"), @@ -145,6 +146,7 @@ public enum Failure { INVALID_TYPE_IN_MULTI_VALUE(Type.TRAP, "type of value in multi-value does not match the function type"), NULL_REFERENCE(Type.TRAP, "defined element is ref.null"), + NULL_FUNCTION_REFERENCE(Type.TRAP, "null function reference"), OUT_OF_BOUNDS_TABLE_ACCESS(Type.TRAP, "out of bounds table access"), // GraalWasm-specific: TABLE_INSTANCE_SIZE_LIMIT_EXCEEDED(Type.TRAP, "table instance size exceeds limit"), diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java index a5af133da711..22ad8abdebfc 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java @@ -586,7 +586,7 @@ public Object executeBodyFromOffset(WasmInstance instance, VirtualFrame frame, i assert functionInstanceContext == WasmContext.get(this); // Validate that the target function type matches the expected type of the - // indirect call by performing an equivalence-class check. + // indirect call. if (!symtab.closedTypeAt(expectedFunctionTypeIndex).matches(new SymbolTable.ClosedReferenceType(false, function.closedType()))) { enterErrorBranch(); failFunctionTypeCheck(function, expectedFunctionTypeIndex); @@ -1495,7 +1495,7 @@ public Object executeBodyFromOffset(WasmInstance instance, VirtualFrame frame, i final Object refType = popReference(frame, stackPointer - 1); pushInt(frame, stackPointer - 1, refType == WasmConstant.NULL ? 1 : 0); break; - case Bytecode.REF_FUNC: + case Bytecode.REF_FUNC: { final int functionIndex = rawPeekI32(bytecode, offset); final WasmFunction function = module.symbolTable().function(functionIndex); final WasmFunctionInstance functionInstance = instance.functionInstance(function); @@ -1503,6 +1503,7 @@ public Object executeBodyFromOffset(WasmInstance instance, VirtualFrame frame, i stackPointer++; offset += 4; break; + } case Bytecode.TABLE_GET: { final int tableIndex = rawPeekI32(bytecode, offset); table_get(instance, frame, stackPointer, tableIndex); @@ -1678,6 +1679,108 @@ public Object executeBodyFromOffset(WasmInstance instance, VirtualFrame frame, i assert exception instanceof WasmRuntimeException : "Only wasm exceptions can be thrown by throw_ref"; throw (WasmRuntimeException) exception; } + case Bytecode.CALL_REF_U8: + case Bytecode.CALL_REF_I32: { + final int callNodeIndex; + final int expectedFunctionTypeIndex; + if (miscOpcode == Bytecode.CALL_REF_U8) { + callNodeIndex = rawPeekU8(bytecode, offset); + expectedFunctionTypeIndex = rawPeekU8(bytecode, offset + 1); + offset += 2; + } else { + callNodeIndex = rawPeekI32(bytecode, offset); + expectedFunctionTypeIndex = rawPeekI32(bytecode, offset + 4); + offset += 8; + } + + // Extract the function object. + final WasmFunctionInstance functionInstance; + final CallTarget target; + final WasmContext functionInstanceContext; + final Object functionOrNull = popReference(frame, --stackPointer); + if (functionOrNull == WasmConstant.NULL) { + enterErrorBranch(); + throw WasmException.format(Failure.NULL_FUNCTION_REFERENCE, this, "Function reference is null"); + } else if (functionOrNull instanceof WasmFunctionInstance) { + functionInstance = (WasmFunctionInstance) functionOrNull; + target = functionInstance.target(); + functionInstanceContext = functionInstance.context(); + } else { + enterErrorBranch(); + throw WasmException.format(Failure.UNSPECIFIED_TRAP, this, "Unknown function object: %s", functionOrNull); + } + + // Target function instance must be from the same context. + assert functionInstanceContext == WasmContext.get(this); + + // Invoke the resolved function. + int paramCount = module.symbolTable().functionTypeParamCount(expectedFunctionTypeIndex); + Object[] args = createArgumentsForCall(frame, expectedFunctionTypeIndex, paramCount, stackPointer); + stackPointer -= paramCount; + WasmArguments.setModuleInstance(args, functionInstance.moduleInstance()); + + final Object result = executeIndirectCallNode(callNodeIndex, target, args); + stackPointer = pushIndirectCallResult(frame, stackPointer, expectedFunctionTypeIndex, result, WasmLanguage.get(this)); + CompilerAsserts.partialEvaluationConstant(stackPointer); + break; + } + case Bytecode.REF_AS_NON_NULL: { + Object reference = popReference(frame, stackPointer - 1); + if (reference == WasmConstant.NULL) { + enterErrorBranch(); + throw WasmException.format(Failure.NULL_REFERENCE, this, "Function reference is null"); + } + pushReference(frame, stackPointer - 1, reference); + break; + } + case Bytecode.BR_ON_NULL_U8: { + Object reference = popReference(frame, --stackPointer); + if (profileCondition(bytecode, offset + 1, reference == WasmConstant.NULL)) { + final int offsetDelta = rawPeekU8(bytecode, offset); + // BR_ON_NULL_U8 encodes the back jump value as a positive byte value. BR_ON_NULL_U8 + // can never perform a forward jump. + offset -= offsetDelta; + } else { + offset += 3; + pushReference(frame, stackPointer++, reference); + } + break; + } + case Bytecode.BR_ON_NULL_I32: { + Object reference = popReference(frame, --stackPointer); + if (profileCondition(bytecode, offset + 4, reference == WasmConstant.NULL)) { + final int offsetDelta = rawPeekI32(bytecode, offset); + offset += offsetDelta; + } else { + offset += 6; + pushReference(frame, stackPointer++, reference); + } + break; + } + case Bytecode.BR_ON_NON_NULL_U8: { + Object reference = popReference(frame, --stackPointer); + if (profileCondition(bytecode, offset + 1, reference != WasmConstant.NULL)) { + final int offsetDelta = rawPeekU8(bytecode, offset); + // BR_ON_NULL_U8 encodes the back jump value as a positive byte value. BR_ON_NULL_U8 + // can never perform a forward jump. + offset -= offsetDelta; + pushReference(frame, stackPointer++, reference); + } else { + offset += 3; + } + break; + } + case Bytecode.BR_ON_NON_NULL_I32: { + Object reference = popReference(frame, --stackPointer); + if (profileCondition(bytecode, offset + 4, reference != WasmConstant.NULL)) { + final int offsetDelta = rawPeekI32(bytecode, offset); + offset += offsetDelta; + pushReference(frame, stackPointer++, reference); + } else { + offset += 6; + } + break; + } default: throw CompilerDirectives.shouldNotReachHere(); } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/BytecodeParser.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/BytecodeParser.java index fbe5115ce671..80dc61b2e2fc 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/BytecodeParser.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/BytecodeParser.java @@ -790,6 +790,16 @@ private static List readCallNodes(byte[] bytecode, int startOffset, in int miscOpcode = rawPeekU8(bytecode, offset); offset++; switch (miscOpcode) { + case Bytecode.CALL_REF_U8: { + callNodes.add(new CallNode(originalOffset)); + offset += 3; + break; + } + case Bytecode.CALL_REF_I32: { + callNodes.add(new CallNode(originalOffset)); + offset += 12; + break; + } case Bytecode.I32_TRUNC_SAT_F32_S: case Bytecode.I32_TRUNC_SAT_F32_U: case Bytecode.I32_TRUNC_SAT_F64_S: @@ -801,6 +811,11 @@ private static List readCallNodes(byte[] bytecode, int startOffset, in case Bytecode.THROW_REF: { break; } + case Bytecode.BR_ON_NULL_U8: + case Bytecode.BR_ON_NON_NULL_U8: { + offset += 3; + break; + } case Bytecode.MEMORY_FILL: case Bytecode.MEMORY64_FILL: case Bytecode.MEMORY64_SIZE: @@ -815,6 +830,11 @@ private static List readCallNodes(byte[] bytecode, int startOffset, in offset += 4; break; } + case Bytecode.BR_ON_NULL_I32: + case Bytecode.BR_ON_NON_NULL_I32: { + offset += 6; + break; + } case Bytecode.MEMORY_INIT: case Bytecode.MEMORY64_INIT: case Bytecode.MEMORY_COPY: diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/RuntimeBytecodeGen.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/RuntimeBytecodeGen.java index 36b17b0cb132..e63b7e6e9845 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/RuntimeBytecodeGen.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/RuntimeBytecodeGen.java @@ -444,6 +444,68 @@ public int addBranchIfLocation() { return location; } + public void addBranchOnNull(int location) { + assert location >= 0; + final int relativeOffset = location - (location() + 1); + if (relativeOffset <= 0 && relativeOffset >= -255) { + add1(Bytecode.MISC); + add1(Bytecode.BR_ON_NULL_U8); + // target + add1(-relativeOffset); + // profile + addProfile(); + } else { + add1(Bytecode.MISC); + add1(Bytecode.BR_ON_NULL_I32); + // target + add4(relativeOffset); + // profile + addProfile(); + } + } + + public int addBranchOnNullLocation() { + add1(Bytecode.MISC); + add1(Bytecode.BR_ON_NULL_I32); + final int location = location(); + // target + add4(0); + // profile + addProfile(); + return location; + } + + public void addBranchOnNonNull(int location) { + assert location >= 0; + final int relativeOffset = location - (location() + 1); + if (relativeOffset <= 0 && relativeOffset >= -255) { + add1(Bytecode.MISC); + add1(Bytecode.BR_ON_NON_NULL_U8); + // target + add1(-relativeOffset); + // profile + addProfile(); + } else { + add1(Bytecode.MISC); + add1(Bytecode.BR_ON_NON_NULL_I32); + // target + add4(relativeOffset); + // profile + addProfile(); + } + } + + public int addBranchOnNonNullLocation() { + add1(Bytecode.MISC); + add1(Bytecode.BR_ON_NON_NULL_I32); + final int location = location(); + // target + add4(0); + // profile + addProfile(); + return location; + } + /** * Adds a branch table opcode to the bytecode. If the size fits into an u8 value, a br_table_u8 * and u8 size are added. Otherwise, a br_table_i32 and i32 size are added. In both cases, a @@ -539,9 +601,8 @@ public void addCall(int nodeIndex, int functionIndex) { /** * Adds an indirect call instruction to the bytecode. If the nodeIndex, typeIndex, and * tableIndex all fit into a u8 value, a call_indirect_u8 and three u8 values are added. - * Otherwise, a call_indirect_i32 and three i32 values are added. In both cases, a 2-byte - * profile is added. - * + * Otherwise, a call_indirect_i32 and three i32 values are added. + * * @param nodeIndex The node index of the indirect call * @param typeIndex The type index of the indirect call * @param tableIndex The table index of the indirect call @@ -560,6 +621,28 @@ public void addIndirectCall(int nodeIndex, int typeIndex, int tableIndex) { } } + /** + * Adds a reference call instruction to the bytecode. If the nodeIndex and typeIndex both fit + * into a u8 value, a call_ref_u8 and two u8 values are added. Otherwise, a call_ref_i32 and + * two i32 values are added. + * + * @param nodeIndex The node index of the reference call + * @param typeIndex The type index of the reference call + */ + public void addRefCall(int nodeIndex, int typeIndex) { + if (fitsIntoUnsignedByte(nodeIndex) && fitsIntoUnsignedByte(typeIndex)) { + add1(Bytecode.MISC); + add1(Bytecode.CALL_REF_U8); + add1(nodeIndex); + add1(typeIndex); + } else { + add1(Bytecode.MISC); + add1(Bytecode.CALL_REF_I32); + add4(nodeIndex); + add4(typeIndex); + } + } + public void addSelect(int instruction) { add1(instruction); addProfile(); diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/BlockFrame.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/BlockFrame.java index b2848fc9c212..564cbd44ffb0 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/BlockFrame.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/BlockFrame.java @@ -95,6 +95,16 @@ void addBranchIf(RuntimeBytecodeGen bytecode) { branches.add(bytecode.addBranchIfLocation()); } + @Override + void addBranchOnNull(RuntimeBytecodeGen bytecode) { + branches.add(bytecode.addBranchOnNullLocation()); + } + + @Override + void addBranchOnNonNull(RuntimeBytecodeGen bytecode) { + branches.add(bytecode.addBranchOnNonNullLocation()); + } + @Override void addBranchTableItem(RuntimeBytecodeGen bytecode) { branches.add(bytecode.addBranchTableItemLocation()); diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ControlFrame.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ControlFrame.java index 6d545b4fb961..4f1e2015cb2c 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ControlFrame.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ControlFrame.java @@ -145,6 +145,10 @@ protected void resetUnreachable() { abstract void addBranchIf(RuntimeBytecodeGen bytecode); + abstract void addBranchOnNull(RuntimeBytecodeGen bytecode); + + abstract void addBranchOnNonNull(RuntimeBytecodeGen bytecode); + /** * Adds a branch table item targeting this control frame. Automatically patches the branch * target as soon as it is available. diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/IfFrame.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/IfFrame.java index 01968552042f..066191a70f65 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/IfFrame.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/IfFrame.java @@ -112,6 +112,16 @@ void addBranchIf(RuntimeBytecodeGen bytecode) { branchTargets.add(bytecode.addBranchIfLocation()); } + @Override + void addBranchOnNull(RuntimeBytecodeGen bytecode) { + branchTargets.add(bytecode.addBranchOnNullLocation()); + } + + @Override + void addBranchOnNonNull(RuntimeBytecodeGen bytecode) { + branchTargets.add(bytecode.addBranchOnNonNullLocation()); + } + @Override void addBranchTableItem(RuntimeBytecodeGen bytecode) { branchTargets.add(bytecode.addBranchTableItemLocation()); diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/LoopFrame.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/LoopFrame.java index 9ce7ecb83a4e..dd970b1c4572 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/LoopFrame.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/LoopFrame.java @@ -80,6 +80,16 @@ void addBranchIf(RuntimeBytecodeGen bytecode) { bytecode.addBranchIf(labelLocation); } + @Override + void addBranchOnNull(RuntimeBytecodeGen bytecode) { + bytecode.addBranchOnNull(labelLocation); + } + + @Override + void addBranchOnNonNull(RuntimeBytecodeGen bytecode) { + bytecode.addBranchOnNonNull(labelLocation); + } + @Override void addBranchTableItem(RuntimeBytecodeGen bytecode) { bytecode.patchLocation(bytecode.addBranchTableItemLocation(), labelLocation); diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ParserState.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ParserState.java index 59dd04da05a5..d5ad72097828 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ParserState.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ParserState.java @@ -90,7 +90,7 @@ public ParserState(RuntimeBytecodeGen bytecode, SymbolTable symbolTable) { private int popInternal(int expectedValueType) { if (availableStackSize() == 0) { if (isCurrentStackUnreachable()) { - return WasmType.UNKNOWN_TYPE; + return WasmType.BOT; } else { if (expectedValueType == WasmType.TOP) { throw ValidationErrors.createExpectedTopOnEmptyStack(); @@ -200,7 +200,7 @@ public int pop() { */ public int popChecked(int expectedValueType) { final int actualValueType = popInternal(expectedValueType); - if (!symbolTable.matches(expectedValueType, actualValueType) && actualValueType != WasmType.UNKNOWN_TYPE && expectedValueType != WasmType.UNKNOWN_TYPE) { + if (!WasmType.isBottomType(actualValueType) && !WasmType.isBottomType(expectedValueType) && !symbolTable.matches(expectedValueType, actualValueType)) { throw ValidationErrors.createTypeMismatch(expectedValueType, actualValueType); } return actualValueType; @@ -210,18 +210,19 @@ public int popChecked(int expectedValueType) { * Pops the topmost value type from the stack and checks if it is a reference type. * * @throws WasmException If the stack is empty or the value type is not a reference type. + * @return The reference type on top of the stack. */ - public void popReferenceTypeChecked() { + public int popReferenceTypeChecked() { if (availableStackSize() != 0) { final int value = valueStack.popBack(); if (WasmType.isReferenceType(value)) { - return; + return value; } // Push value back onto the stack and perform a checked pop to get the correct error // message valueStack.push(value); } - popChecked(WasmType.FUNCREF_TYPE); + return popChecked(WasmType.FUNCREF_TYPE); } /** @@ -418,6 +419,34 @@ public void addUnconditionalBranch(int branchLabel) { frame.addBranch(bytecode); } + public void addBranchOnNull(int branchLabel) { + checkLabelExists(branchLabel); + ControlFrame frame = getFrame(branchLabel); + final int[] labelTypes = frame.labelTypes(); + popAll(labelTypes); + pushAll(labelTypes); + frame.addBranchOnNull(bytecode); + } + + public void addBranchOnNonNull(int branchLabel, int referenceType) { + checkLabelExists(branchLabel); + ControlFrame frame = getFrame(branchLabel); + final int[] labelTypes = frame.labelTypes(); + if (labelTypes.length < 1) { + throw ValidationErrors.createLabelTypesMismatch(labelTypes, new int[]{referenceType}); + } + if (!symbolTable.matches(labelTypes[labelTypes.length - 1], referenceType)) { + throw ValidationErrors.createTypeMismatch(labelTypes[labelTypes.length - 1], referenceType); + } + for (int i = labelTypes.length - 2; i >= 0; i--) { + popChecked(labelTypes[i]); + } + for (int i = 0; i < labelTypes.length - 1; i++) { + push(labelTypes[i]); + } + frame.addBranchOnNonNull(bytecode); + } + /** * Performs the necessary branch checks and adds the branch table information to the extra data * array. @@ -457,18 +486,34 @@ public void addReturn(boolean multiValue) { } /** - * Adds the index of an indirect call node to the extra data array. + * Adds a reference call instruction to the bytecode, along with its immediate argument and the + * call node index. + * + * @param nodeIndex The index of the call node associated with this call instruction. + * @param typeIndex The index of the defined function type. + */ + public void addRefCall(int nodeIndex, int typeIndex) { + bytecode.addRefCall(nodeIndex, typeIndex); + } + + /** + * Adds an indirect call instruction to the bytecode, along with its immediate arguments and the + * call node index. * - * @param nodeIndex The index of the indirect call. + * @param nodeIndex The index of the call node associated with this call instruction. + * @param typeIndex The index of the defined function type. + * @param tableIndex The index of the table in which the function will be looked up. */ public void addIndirectCall(int nodeIndex, int typeIndex, int tableIndex) { bytecode.addIndirectCall(nodeIndex, typeIndex, tableIndex); } /** - * Adds the index of a direct call node to the extra data array. + * Adds a direct call instruction to the bytecode, along with its immediate argument and the + * call node index. * - * @param nodeIndex The index of the direct call. + * @param nodeIndex The index of the call node associated with this call instruction. + * @param functionIndex The index of the defined function. */ public void addCall(int nodeIndex, int functionIndex) { bytecode.addCall(nodeIndex, functionIndex); @@ -743,19 +788,6 @@ public void checkLabelTypes(int[] expectedTypes, int[] actualTypes) { } } - /** - * Checks if the given function type is within range. - * - * @param typeIndex The function type. - * @param max The number of available function types. - * @throws WasmException If the given function type is greater or equal to the given maximum. - */ - public void checkFunctionTypeExists(int typeIndex, int max) { - if (compareUnsigned(typeIndex, max) >= 0) { - throw ValidationErrors.createMissingFunctionType(typeIndex, max - 1); - } - } - /** * Sets the current control stack unreachable. */ From a2914f4a54604d7bb1becf59ccb9da2a770388a0 Mon Sep 17 00:00:00 2001 From: Jirka Marsik Date: Mon, 29 Sep 2025 16:54:56 +0200 Subject: [PATCH 05/40] Check that non-defaultable locals are initialized before access --- .../src/org/graalvm/wasm/BinaryParser.java | 9 +++++-- .../src/org/graalvm/wasm/WasmType.java | 4 +++ .../org/graalvm/wasm/exception/Failure.java | 1 + .../wasm/parser/validation/BlockFrame.java | 22 ++++++++++++++-- .../wasm/parser/validation/ControlFrame.java | 17 ++++++++++--- .../wasm/parser/validation/IfFrame.java | 12 ++++++--- .../wasm/parser/validation/LoopFrame.java | 6 +++-- .../wasm/parser/validation/ParserState.java | 25 ++++++++++++------- .../wasm/parser/validation/TryTableFrame.java | 4 +-- 9 files changed, 76 insertions(+), 24 deletions(-) diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java index b3c078008b6d..f3c921f10dd4 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java @@ -118,6 +118,8 @@ public class BinaryParser extends BinaryStreamParser { private static final int MAGIC = 0x6d736100; private static final int VERSION = 0x00000001; + private static final int[] EMPTY_LOCALS = new int[0]; + private final WasmModule module; private final WasmContext wasmContext; private final int[] multiResult; @@ -607,7 +609,7 @@ private CodeEntry readFunction(int functionIndex, int[] locals, int[] resultType final ParserState state = new ParserState(bytecode, module); final ArrayList callNodes = new ArrayList<>(); final int bytecodeStartOffset = bytecode.location(); - state.enterFunction(resultTypes); + state.enterFunction(resultTypes, locals); int opcode; end: while (offset < sourceCodeEndOffset) { @@ -906,6 +908,7 @@ private CodeEntry readFunction(int functionIndex, int[] locals, int[] resultType case Instructions.LOCAL_GET: { final int localIndex = readLocalIndex(); assertUnsignedIntLess(localIndex, locals.length, Failure.UNKNOWN_LOCAL); + Assert.assertTrue(state.isLocalInitialized(localIndex), Failure.UNINITIALIZED_LOCAL); final int localType = locals[localIndex]; state.push(localType); if (WasmType.isNumberType(localType)) { @@ -918,6 +921,7 @@ private CodeEntry readFunction(int functionIndex, int[] locals, int[] resultType case Instructions.LOCAL_SET: { final int localIndex = readLocalIndex(); assertUnsignedIntLess(localIndex, locals.length, Failure.UNKNOWN_LOCAL); + state.initializeLocal(localIndex); final int localType = locals[localIndex]; state.popChecked(localType); if (WasmType.isNumberType(localType)) { @@ -930,6 +934,7 @@ private CodeEntry readFunction(int functionIndex, int[] locals, int[] resultType case Instructions.LOCAL_TEE: { final int localIndex = readLocalIndex(); assertUnsignedIntLess(localIndex, locals.length, Failure.UNKNOWN_LOCAL); + state.initializeLocal(localIndex); final int localType = locals[localIndex]; state.popChecked(localType); state.push(localType); @@ -2649,7 +2654,7 @@ private Pair readConstantExpression(int resultType, boolean only final List stack = new ArrayList<>(); boolean calculable = true; - state.enterFunction(new int[]{resultType}); + state.enterFunction(new int[]{resultType}, EMPTY_LOCALS); int opcode; while ((opcode = read1() & 0xFF) != Instructions.END) { switch (opcode) { diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmType.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmType.java index 4bd32e073aae..ad0bcc9f204a 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmType.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmType.java @@ -193,6 +193,10 @@ public static int withNullable(boolean nullable, int type) { return nullable ? type | TYPE_NULLABLE_MASK : type & ~TYPE_NULLABLE_MASK; } + public static boolean hasDefaultValue(int type) { + return !(isReferenceType(type) && !isNullable(type)); + } + public static int getCommonValueType(int[] types) { int type = 0; for (int resultType : types) { diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/Failure.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/Failure.java index be435258ea12..ee6ddbab11fc 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/Failure.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/Failure.java @@ -82,6 +82,7 @@ public enum Failure { MULTIPLE_TABLES(Type.INVALID, "multiple tables"), LOOP_INPUT(Type.INVALID, "non-empty loop input type"), UNKNOWN_LOCAL(Type.INVALID, "unknown local"), + UNINITIALIZED_LOCAL(Type.INVALID, "uninitialized local"), UNKNOWN_GLOBAL(Type.INVALID, "unknown global"), UNKNOWN_MEMORY(Type.INVALID, "unknown memory"), UNKNOWN_TABLE(Type.INVALID, "unknown table"), diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/BlockFrame.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/BlockFrame.java index 564cbd44ffb0..1da77891fd6d 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/BlockFrame.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/BlockFrame.java @@ -42,7 +42,9 @@ package org.graalvm.wasm.parser.validation; import java.util.ArrayList; +import java.util.BitSet; +import org.graalvm.wasm.WasmType; import org.graalvm.wasm.collection.IntArrayList; import org.graalvm.wasm.exception.Failure; import org.graalvm.wasm.exception.WasmException; @@ -52,15 +54,31 @@ * Representation of a wasm block during module validation. */ class BlockFrame extends ControlFrame { + private static final int[] EMPTY_ARRAY = new int[0]; + private final IntArrayList branches; private final ArrayList exceptionHandlers; - BlockFrame(int[] paramTypes, int[] resultTypes, int initialStackSize, boolean unreachable) { - super(paramTypes, resultTypes, initialStackSize, unreachable); + private BlockFrame(int[] paramTypes, int[] resultTypes, int initialStackSize, BitSet initializedLocals) { + super(paramTypes, resultTypes, initialStackSize, initializedLocals); branches = new IntArrayList(); exceptionHandlers = new ArrayList<>(); } + BlockFrame(int[] paramTypes, int[] resultTypes, int initialStackSize, ControlFrame parentFrame) { + this(paramTypes, resultTypes, initialStackSize, (BitSet) parentFrame.initializedLocals.clone()); + } + + static BlockFrame createFunctionFrame(int[] resultTypes, int[] locals) { + BitSet initializedLocals = new BitSet(locals.length); + for (int localIndex = 0; localIndex < locals.length; localIndex++) { + if (WasmType.hasDefaultValue(locals[localIndex])) { + initializedLocals.set(localIndex); + } + } + return new BlockFrame(EMPTY_ARRAY, resultTypes, 0, initializedLocals); + } + @Override int[] labelTypes() { return resultTypes(); diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ControlFrame.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ControlFrame.java index 4f1e2015cb2c..f6cc06575800 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ControlFrame.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ControlFrame.java @@ -44,6 +44,8 @@ import org.graalvm.wasm.WasmType; import org.graalvm.wasm.parser.bytecode.RuntimeBytecodeGen; +import java.util.BitSet; + /** * Represents the scope of a block structure during module validation. */ @@ -54,19 +56,20 @@ public abstract class ControlFrame { private final int initialStackSize; private boolean unreachable; private final int commonResultType; + protected BitSet initializedLocals; /** * @param paramTypes The parameter value types of the block structure. * @param resultTypes The result value types of the block structure. * @param initialStackSize The size of the value stack when entering this block structure. - * @param unreachable If the block structure should be declared unreachable. */ - ControlFrame(int[] paramTypes, int[] resultTypes, int initialStackSize, boolean unreachable) { + ControlFrame(int[] paramTypes, int[] resultTypes, int initialStackSize, BitSet initializedLocals) { this.paramTypes = paramTypes; this.resultTypes = resultTypes; this.initialStackSize = initialStackSize; - this.unreachable = unreachable; + this.unreachable = false; commonResultType = WasmType.getCommonValueType(resultTypes); + this.initializedLocals = (BitSet) initializedLocals.clone(); } protected int[] paramTypes() { @@ -113,6 +116,14 @@ protected void resetUnreachable() { this.unreachable = false; } + boolean isLocalInitialized(int localIndex) { + return initializedLocals.get(localIndex); + } + + void initializeLocal(int localIndex) { + initializedLocals.set(localIndex); + } + /** * Performs checks and actions when entering an else branch. * diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/IfFrame.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/IfFrame.java index 066191a70f65..672d4ea05678 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/IfFrame.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/IfFrame.java @@ -43,6 +43,7 @@ import java.util.ArrayList; import java.util.Arrays; +import java.util.BitSet; import org.graalvm.wasm.collection.IntArrayList; import org.graalvm.wasm.exception.Failure; @@ -56,13 +57,15 @@ class IfFrame extends ControlFrame { private final IntArrayList branchTargets; private final ArrayList exceptionHandlers; + private final ControlFrame parentFrame; private int falseJumpLocation; private boolean elseBranch; - IfFrame(int[] paramTypes, int[] resultTypes, int initialStackSize, boolean unreachable, int falseJumpLocation) { - super(paramTypes, resultTypes, initialStackSize, unreachable); - branchTargets = new IntArrayList(); - exceptionHandlers = new ArrayList<>(); + IfFrame(int[] paramTypes, int[] resultTypes, int initialStackSize, ControlFrame parentFrame, int falseJumpLocation) { + super(paramTypes, resultTypes, initialStackSize, (BitSet) parentFrame.initializedLocals.clone()); + this.branchTargets = new IntArrayList(); + this.exceptionHandlers = new ArrayList<>(); + this.parentFrame = parentFrame; this.falseJumpLocation = falseJumpLocation; this.elseBranch = false; } @@ -74,6 +77,7 @@ int[] labelTypes() { @Override void enterElse(ParserState state, RuntimeBytecodeGen bytecode) { + initializedLocals = (BitSet) parentFrame.initializedLocals.clone(); final int location = bytecode.addBranchLocation(); bytecode.patchLocation(falseJumpLocation, bytecode.location()); falseJumpLocation = location; diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/LoopFrame.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/LoopFrame.java index dd970b1c4572..2596f4a2da83 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/LoopFrame.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/LoopFrame.java @@ -45,14 +45,16 @@ import org.graalvm.wasm.exception.WasmException; import org.graalvm.wasm.parser.bytecode.RuntimeBytecodeGen; +import java.util.BitSet; + /** * Representation of a wasm loop during module validation. */ class LoopFrame extends ControlFrame { private final int labelLocation; - LoopFrame(int[] paramTypes, int[] resultTypes, int initialStackSize, boolean unreachable, int labelLocation) { - super(paramTypes, resultTypes, initialStackSize, unreachable); + LoopFrame(int[] paramTypes, int[] resultTypes, int initialStackSize, ControlFrame parentFrame, int labelLocation) { + super(paramTypes, resultTypes, initialStackSize, (BitSet) parentFrame.initializedLocals.clone()); this.labelLocation = labelLocation; } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ParserState.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ParserState.java index d5ad72097828..cc34209feb4d 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ParserState.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ParserState.java @@ -60,8 +60,6 @@ * additional information used to generate parser nodes. */ public class ParserState { - private static final int[] EMPTY_ARRAY = new int[0]; - private final IntArrayList valueStack; private final ControlStack controlStack; private final RuntimeBytecodeGen bytecode; @@ -267,8 +265,9 @@ private void unwindStack(int size) { } } - public void enterFunction(int[] resultTypes) { - enterBlock(EMPTY_ARRAY, resultTypes); + public void enterFunction(int[] resultTypes, int[] locals) { + ControlFrame frame = BlockFrame.createFunctionFrame(resultTypes, locals); + controlStack.push(frame); } /** @@ -279,7 +278,7 @@ public void enterFunction(int[] resultTypes) { * @param resultTypes The result types of the block that was entered. */ public void enterBlock(int[] paramTypes, int[] resultTypes) { - ControlFrame frame = new BlockFrame(paramTypes, resultTypes, valueStack.size(), false); + ControlFrame frame = new BlockFrame(paramTypes, resultTypes, valueStack.size(), controlStack.peek()); controlStack.push(frame); pushAll(paramTypes); } @@ -293,7 +292,7 @@ public void enterBlock(int[] paramTypes, int[] resultTypes) { */ public void enterLoop(int[] paramTypes, int[] resultTypes) { final int label = bytecode.addLoopLabel(paramTypes.length, valueStack.size(), WasmType.getCommonValueType(resultTypes)); - ControlFrame frame = new LoopFrame(paramTypes, resultTypes, valueStack.size(), false, label); + ControlFrame frame = new LoopFrame(paramTypes, resultTypes, valueStack.size(), controlStack.peek(), label); controlStack.push(frame); pushAll(paramTypes); } @@ -307,7 +306,7 @@ public void enterLoop(int[] paramTypes, int[] resultTypes) { */ public void enterIf(int[] paramTypes, int[] resultTypes) { final int fixupLocation = bytecode.addIfLocation(); - ControlFrame frame = new IfFrame(paramTypes, resultTypes, valueStack.size(), false, fixupLocation); + ControlFrame frame = new IfFrame(paramTypes, resultTypes, valueStack.size(), controlStack.peek(), fixupLocation); controlStack.push(frame); pushAll(paramTypes); } @@ -330,7 +329,7 @@ public void enterElse() { * @param handlers The exception handlers of the try table that was entered. */ public void enterTryTable(int[] paramTypes, int[] resultTypes, ExceptionHandler[] handlers) { - final TryTableFrame frame = new TryTableFrame(paramTypes, resultTypes, valueStack.size(), false, bytecode.location(), handlers); + final TryTableFrame frame = new TryTableFrame(paramTypes, resultTypes, valueStack.size(), controlStack.peek(), bytecode.location(), handlers); controlStack.push(frame); exceptionTables.add(frame.table()); @@ -350,7 +349,7 @@ public ExceptionHandler enterCatchClause(int opcode, int tag, int label) { checkLabelExists(label); final ControlFrame labelFrame = getFrame(label); // we reuse the block frame, instead of introducing a new catch frame. - final ControlFrame frame = new BlockFrame(WasmType.VOID_TYPE_ARRAY, labelFrame.labelTypes(), labelFrame.initialStackSize(), false); + final ControlFrame frame = new BlockFrame(WasmType.VOID_TYPE_ARRAY, labelFrame.labelTypes(), labelFrame.initialStackSize(), controlStack.peek()); controlStack.push(frame); final ExceptionHandler e = new ExceptionHandler(opcode, tag); labelFrame.addExceptionHandler(e); @@ -725,6 +724,14 @@ public ControlFrame getRootBlock() { return controlStack.getFirst(); } + public boolean isLocalInitialized(int localIndex) { + return controlStack.peek().isLocalInitialized(localIndex); + } + + public void initializeLocal(int localIndex) { + controlStack.peek().initializeLocal(localIndex); + } + /** * Checks if the return value types of the given control frame match the remaining value types * on the stack. diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/TryTableFrame.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/TryTableFrame.java index 90c34e5cd050..ddfe213a700a 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/TryTableFrame.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/TryTableFrame.java @@ -49,8 +49,8 @@ public class TryTableFrame extends BlockFrame { private final ExceptionTable table; - TryTableFrame(int[] paramTypes, int[] resultTypes, int initialStackSize, boolean unreachable, int startOffset, ExceptionHandler[] handlers) { - super(paramTypes, resultTypes, initialStackSize, unreachable); + TryTableFrame(int[] paramTypes, int[] resultTypes, int initialStackSize, ControlFrame parentFrame, int startOffset, ExceptionHandler[] handlers) { + super(paramTypes, resultTypes, initialStackSize, parentFrame); this.table = new ExceptionTable(startOffset, handlers); } From 065cd0b0011be09d9180f065511a677ab662c3a1 Mon Sep 17 00:00:00 2001 From: Jirka Marsik Date: Tue, 30 Sep 2025 11:46:37 +0200 Subject: [PATCH 06/40] Support initializer expressions for tables --- .../src/org/graalvm/wasm/BinaryParser.java | 70 ++++++++++++------- .../src/org/graalvm/wasm/Linker.java | 60 ++++++++++++++-- .../src/org/graalvm/wasm/SymbolTable.java | 34 ++++++--- .../org/graalvm/wasm/WasmInstantiator.java | 4 ++ .../wasm/constants/BytecodeBitEncoding.java | 4 ++ .../org/graalvm/wasm/exception/Failure.java | 1 + .../wasm/predefined/BuiltinModule.java | 4 +- 7 files changed, 133 insertions(+), 44 deletions(-) diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java index f3c921f10dd4..6c41a1944e28 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java @@ -59,12 +59,14 @@ import static org.graalvm.wasm.WasmType.FUNC_HEAPTYPE; import static org.graalvm.wasm.WasmType.I32_TYPE; import static org.graalvm.wasm.WasmType.I64_TYPE; -import static org.graalvm.wasm.WasmType.NULL_TYPE; import static org.graalvm.wasm.WasmType.REF_NULL_TYPE_HEADER; import static org.graalvm.wasm.WasmType.REF_TYPE_HEADER; import static org.graalvm.wasm.WasmType.V128_TYPE; import static org.graalvm.wasm.WasmType.VOID_BLOCK_TYPE; import static org.graalvm.wasm.constants.Bytecode.vectorOpcodeToBytecode; +import static org.graalvm.wasm.constants.BytecodeBitEncoding.ELEM_ITEM_REF_FUNC_ENTRY_PREFIX; +import static org.graalvm.wasm.constants.BytecodeBitEncoding.ELEM_ITEM_GLOBAL_GET_ENTRY_PREFIX; +import static org.graalvm.wasm.constants.BytecodeBitEncoding.ELEM_ITEM_REF_NULL_ENTRY_PREFIX; import static org.graalvm.wasm.constants.Sizes.MAX_MEMORY_64_DECLARATION_SIZE; import static org.graalvm.wasm.constants.Sizes.MAX_MEMORY_DECLARATION_SIZE; import static org.graalvm.wasm.constants.Sizes.MAX_TABLE_DECLARATION_SIZE; @@ -492,9 +494,25 @@ private void readTableSection() { module.limits().checkTableCount(startingTableIndex + tableCount); for (int tableIndex = startingTableIndex; tableIndex != startingTableIndex + tableCount; tableIndex++) { assertTrue(!isEOF(), Failure.LENGTH_OUT_OF_BOUNDS); - final int elemType = readRefType(exceptions); - readTableLimits(multiResult); - module.symbolTable().allocateTable(tableIndex, multiResult[0], multiResult[1], elemType, bulkMemoryAndRefTypes); + final int elemType; + final Object initValue; + final byte[] initBytecode; + if (peek1(data, offset) == 0x40 && peek1(data, offset + 1) == 0x00) { + offset += 2; + elemType = readRefType(exceptions); + readTableLimits(multiResult); + Pair initExpression = readConstantExpression(elemType); + initValue = initExpression.getLeft(); + // Drop the initializer bytecode if we can eval the initializer during parsing + initBytecode = initValue == null ? initExpression.getRight() : null; + } else { + elemType = readRefType(exceptions); + readTableLimits(multiResult); + initValue = null; + initBytecode = null; + Assert.assertTrue(WasmType.isNullable(elemType), Failure.UNINITIALIZED_TABLE); + } + module.symbolTable().declareTable(tableIndex, multiResult[0], multiResult[1], elemType, initBytecode, initValue, bulkMemoryAndRefTypes); } } @@ -2628,7 +2646,7 @@ private Pair readOffsetExpression() { // Table offset expression must be a constant expression with result type i32. // https://webassembly.github.io/spec/core/syntax/modules.html#element-segments // https://webassembly.github.io/spec/core/valid/instructions.html#constant-expressions - Pair result = readConstantExpression(I32_TYPE, true); + Pair result = readConstantExpression(I32_TYPE); if (result.getRight() == null) { return Pair.create((int) result.getLeft(), null); } else { @@ -2637,7 +2655,7 @@ private Pair readOffsetExpression() { } private Pair readLongOffsetExpression() { - Pair result = readConstantExpression(I64_TYPE, true); + Pair result = readConstantExpression(I64_TYPE); if (result.getRight() == null) { return Pair.create((long) result.getLeft(), null); } else { @@ -2645,7 +2663,7 @@ private Pair readLongOffsetExpression() { } } - private Pair readConstantExpression(int resultType, boolean onlyImportedGlobals) { + private Pair readConstantExpression(int resultType) { // Read the constant expression. // https://webassembly.github.io/spec/core/valid/instructions.html#constant-expressions final RuntimeBytecodeGen bytecode = new RuntimeBytecodeGen(); @@ -2717,12 +2735,6 @@ private Pair readConstantExpression(int resultType, boolean only break; case Instructions.GLOBAL_GET: { final int index = readGlobalIndex(); - if (onlyImportedGlobals) { - // The current WebAssembly spec says constant expressions can only refer to - // imported globals. We can easily remove this restriction in the future. - assertUnsignedIntLess(index, module.symbolTable().importedGlobals().size(), Failure.UNKNOWN_GLOBAL, - "Constant expression in module '%s' refers to non-imported global %d.", module.name(), index); - } assertIntEqual(module.globalMutability(index), GlobalModifier.CONSTANT, Failure.CONSTANT_EXPRESSION_REQUIRED); state.push(module.symbolTable().globalValueType(index)); state.addUnsignedInstruction(Bytecode.GLOBAL_GET_U8, index); @@ -2804,14 +2816,16 @@ private Pair readConstantExpression(int resultType, boolean only } } - private long[] readFunctionIndices() { + private long[] readFunctionIndices(int elemType) { final int functionIndexCount = readLength(); final long[] functionIndices = new long[functionIndexCount]; for (int index = 0; index != functionIndexCount; index++) { assertTrue(!isEOF(), Failure.LENGTH_OUT_OF_BOUNDS); final int functionIndex = readDeclaredFunctionIndex(); module.addFunctionReference(functionIndex); - functionIndices[index] = ((long) FUNCREF_TYPE << 32) | functionIndex; + final int functionReferenceType = WasmType.withNullable(false, module.function(functionIndex).typeIndex()); + Assert.assertTrue(module.matches(elemType, functionReferenceType), Failure.TYPE_MISMATCH); + functionIndices[index] = ((long) ELEM_ITEM_REF_FUNC_ENTRY_PREFIX << 32) | functionIndex; } return functionIndices; } @@ -2825,7 +2839,7 @@ private void checkElemKind() { private long[] readElemExpressions(int elemType) { final int expressionCount = readLength(); - final long[] functionIndices = new long[expressionCount]; + final long[] elements = new long[expressionCount]; for (int index = 0; index != expressionCount; index++) { assertTrue(!isEOF(), Failure.LENGTH_OUT_OF_BOUNDS); int opcode = read1() & 0xFF; @@ -2850,21 +2864,21 @@ private long[] readElemExpressions(int elemType) { final int heapType = readHeapType(exceptions); final int nullableReferenceType = WasmType.withNullable(true, heapType); Assert.assertTrue(module.matches(elemType, nullableReferenceType), "Invalid ref.null type: 0x%02X", Failure.TYPE_MISMATCH); - functionIndices[index] = ((long) NULL_TYPE << 32); + elements[index] = ((long) ELEM_ITEM_REF_NULL_ENTRY_PREFIX << 32); break; case Instructions.REF_FUNC: final int functionIndex = readDeclaredFunctionIndex(); module.addFunctionReference(functionIndex); final int functionReferenceType = WasmType.withNullable(false, module.function(functionIndex).typeIndex()); Assert.assertTrue(module.matches(elemType, functionReferenceType), "Invalid element type: 0x%02X", Failure.TYPE_MISMATCH); - functionIndices[index] = ((long) FUNCREF_TYPE << 32) | functionIndex; + elements[index] = ((long) ELEM_ITEM_REF_FUNC_ENTRY_PREFIX << 32) | functionIndex; break; case Instructions.GLOBAL_GET: final int globalIndex = readGlobalIndex(); assertIntEqual(module.globalMutability(globalIndex), GlobalModifier.CONSTANT, Failure.CONSTANT_EXPRESSION_REQUIRED); final int valueType = module.globalValueType(globalIndex); Assert.assertTrue(module.matches(elemType, valueType), Failure.TYPE_MISMATCH); - functionIndices[index] = ((long) I32_TYPE << 32) | globalIndex; + elements[index] = ((long) ELEM_ITEM_GLOBAL_GET_ENTRY_PREFIX << 32) | globalIndex; break; case Instructions.VECTOR: checkSIMDSupport(); @@ -2880,7 +2894,7 @@ private long[] readElemExpressions(int elemType) { } readEnd(); } - return functionIndices; + return elements; } private void readElementSection(RuntimeBytecodeGen bytecode) { @@ -2925,9 +2939,11 @@ private void readElementSection(RuntimeBytecodeGen bytecode) { } else { if (useType) { checkElemKind(); + elemType = FUNCREF_TYPE; + } else { + elemType = WasmType.withNullable(false, FUNC_HEAPTYPE); } - elemType = FUNCREF_TYPE; - elements = readFunctionIndices(); + elements = readFunctionIndices(elemType); } } else { mode = SegmentMode.ACTIVE; @@ -2935,8 +2951,8 @@ private void readElementSection(RuntimeBytecodeGen bytecode) { Pair offsetExpression = readOffsetExpression(); currentOffsetAddress = offsetExpression.getLeft(); currentOffsetBytecode = offsetExpression.getRight(); - elements = readFunctionIndices(); elemType = FUNCREF_TYPE; + elements = readFunctionIndices(elemType); } // Copy the contents, or schedule a linker task for this. @@ -2960,14 +2976,14 @@ private void readElementSection(RuntimeBytecodeGen bytecode) { for (long element : elements) { final int initType = (int) (element >> 32); switch (initType) { - case NULL_TYPE: + case ELEM_ITEM_REF_NULL_ENTRY_PREFIX: bytecode.addElemNull(); break; - case FUNCREF_TYPE: + case ELEM_ITEM_REF_FUNC_ENTRY_PREFIX: final int functionIndex = (int) element; bytecode.addElemFunctionIndex(functionIndex); break; - case I32_TYPE: + case ELEM_ITEM_GLOBAL_GET_ENTRY_PREFIX: final int globalIndex = (int) element; bytecode.addElemGlobalIndex(globalIndex); break; @@ -3056,7 +3072,7 @@ private void readGlobalSection() { final byte mutability = readMutability(); // Global initialization expressions must be constant expressions: // https://webassembly.github.io/spec/core/valid/instructions.html#constant-expressions - Pair initExpression = readConstantExpression(type, true); + Pair initExpression = readConstantExpression(type); final Object initValue = initExpression.getLeft(); final byte[] initBytecode = initExpression.getRight(); final boolean isInitialized = initBytecode == null; diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java index 997d923a6e74..e8413bf152bb 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java @@ -84,6 +84,7 @@ import org.graalvm.wasm.Linker.ResolutionDag.ImportTableSym; import org.graalvm.wasm.Linker.ResolutionDag.ImportTagSym; import org.graalvm.wasm.Linker.ResolutionDag.InitializeGlobalSym; +import org.graalvm.wasm.Linker.ResolutionDag.InitializeTableSym; import org.graalvm.wasm.Linker.ResolutionDag.Resolver; import org.graalvm.wasm.Linker.ResolutionDag.Sym; import org.graalvm.wasm.api.ExecuteHostFunctionNode; @@ -790,6 +791,29 @@ void resolvePassiveDataSegment(WasmStore store, WasmInstance instance, int dataS resolutionDag.resolveLater(new DataSym(instance.name(), dataSegmentId), dependencies.toArray(new Sym[0]), resolveAction); } + private static void initializeTable(WasmInstance instance, int tableIndex, Object initValue) { + int tableAddress = instance.tableAddress(tableIndex); + WasmTable table = instance.store().tables().table(tableAddress); + table.fill(0, table.size(), initValue); + } + + void resolveTableInitialization(WasmInstance instance, int tableIndex, byte[] initBytecode, Object initValue) { + final Runnable resolveAction; + final Sym[] dependencies; + if (initValue != null) { + initializeTable(instance, tableIndex, initValue); + resolveAction = NO_RESOLVE_ACTION; + dependencies = ResolutionDag.NO_DEPENDENCIES; + } else if (initBytecode != null) { + resolveAction = () -> initializeTable(instance, tableIndex, evalConstantExpression(instance, initBytecode)); + dependencies = dependenciesOfConstantExpression(instance, initBytecode).toArray(ResolutionDag.NO_DEPENDENCIES); + } else { + resolveAction = NO_RESOLVE_ACTION; + dependencies = ResolutionDag.NO_DEPENDENCIES; + } + resolutionDag.resolveLater(new InitializeTableSym(instance.name(), tableIndex), dependencies, resolveAction); + } + void resolveTableImport(WasmStore store, WasmInstance instance, ImportDescriptor importDescriptor, int tableIndex, int declaredMinSize, int declaredMaxSize, int elemType, ImportValueSupplier imports) { final Runnable resolveAction = () -> { @@ -829,13 +853,14 @@ void resolveTableImport(WasmStore store, WasmInstance instance, ImportDescriptor assertIntEqual(elemType, importedTable.elemType(), Failure.INCOMPATIBLE_IMPORT_TYPE); instance.setTableAddress(tableIndex, tableAddress); }; + final ImportTableSym importTableSym = new ImportTableSym(instance.name(), importDescriptor); Sym[] dependencies = new Sym[]{new ExportTableSym(importDescriptor.moduleName(), importDescriptor.memberName())}; - resolutionDag.resolveLater(new ImportTableSym(instance.name(), importDescriptor), dependencies, resolveAction); + resolutionDag.resolveLater(importTableSym, dependencies, resolveAction); + resolutionDag.resolveLater(new InitializeTableSym(instance.name(), tableIndex), new Sym[]{importTableSym}, NO_RESOLVE_ACTION); } void resolveTableExport(WasmModule module, int tableIndex, String exportedTableName) { - final ImportDescriptor importDescriptor = module.symbolTable().importedTable(tableIndex); - final Sym[] dependencies = importDescriptor != null ? new Sym[]{new ImportTableSym(module.name(), importDescriptor)} : ResolutionDag.NO_DEPENDENCIES; + final Sym[] dependencies = new Sym[]{new InitializeTableSym(module.name(), tableIndex)}; resolutionDag.resolveLater(new ExportTableSym(module.name(), exportedTableName), dependencies, NO_RESOLVE_ACTION); } @@ -933,9 +958,7 @@ private static Object[] extractElemItems(WasmInstance instance, int bytecodeOffs void resolveElemSegment(WasmStore store, WasmInstance instance, int tableIndex, int elemSegmentId, int offsetAddress, byte[] offsetBytecode, int bytecodeOffset, int elementCount) { final Runnable resolveAction = () -> immediatelyResolveElemSegment(store, instance, tableIndex, offsetAddress, offsetBytecode, bytecodeOffset, elementCount); final ArrayList dependencies = new ArrayList<>(); - if (instance.symbolTable().importedTable(tableIndex) != null) { - dependencies.add(new ImportTableSym(instance.name(), instance.symbolTable().importedTable(tableIndex))); - } + dependencies.add(new InitializeTableSym(instance.name(), tableIndex)); if (elemSegmentId > 0) { dependencies.add(new ElemSym(instance.name(), elemSegmentId - 1)); } @@ -977,7 +1000,6 @@ void resolvePassiveElemSegment(WasmStore store, WasmInstance instance, int elemS } addElemItemDependencies(instance, bytecodeOffset, elementCount, dependencies); resolutionDag.resolveLater(new ElemSym(instance.name(), elemSegmentId), dependencies.toArray(new Sym[0]), resolveAction); - } public void immediatelyResolvePassiveElementSegment(WasmStore store, WasmInstance instance, int elemSegmentId, int bytecodeOffset, int elementCount) { @@ -1354,6 +1376,30 @@ public boolean equals(Object object) { } } + static class InitializeTableSym extends Sym { + final int tableIndex; + + InitializeTableSym(String moduleName, int tableIndex) { + super(moduleName); + this.tableIndex = tableIndex; + } + + @Override + public String toString() { + return String.format(Locale.ROOT, "(init table %d in %s)", tableIndex, moduleName); + } + + @Override + public int hashCode() { + return Integer.hashCode(tableIndex) ^ moduleName.hashCode(); + } + + @Override + public boolean equals(Object object) { + return object instanceof InitializeTableSym that && this.tableIndex == that.tableIndex && this.moduleName.equals(that.moduleName); + } + } + static class ElemSym extends Sym { final int elemSegmentId; diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java index d1cbcecce507..57ccc2c06636 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java @@ -213,9 +213,9 @@ public int value() { @Override public boolean matches(ClosedHeapType heapSubType) { return switch (this.value) { - case WasmType.FUNC_HEAPTYPE -> this == FUNC || heapSubType instanceof ClosedFunctionType; - case WasmType.EXTERN_HEAPTYPE -> this == EXTERN; - case WasmType.EXN_HEAPTYPE -> this == EXN; + case WasmType.FUNC_HEAPTYPE -> heapSubType == FUNC || heapSubType instanceof ClosedFunctionType; + case WasmType.EXTERN_HEAPTYPE -> heapSubType == EXTERN; + case WasmType.EXN_HEAPTYPE -> heapSubType == EXN; default -> throw CompilerDirectives.shouldNotReachHere(); }; } @@ -296,8 +296,10 @@ public Kind kind() { * Note: this is the upper bound defined by the module. A table instance * might have a lower internal max allowed size in practice. * @param elemType The element type of the table. + * @param initValue The initial value of the table's elements, can be {@code null} if no initializer present + * @param initBytecode The bytecode of the table's initializer expression, can be {@code null} if no initializer present */ - public record TableInfo(int initialSize, int maximumSize, int elemType) { + public record TableInfo(int initialSize, int maximumSize, int elemType, Object initValue, byte[] initBytecode) { } /** @@ -1106,9 +1108,9 @@ private void ensureTableCapacity(int index) { } } - public void allocateTable(int index, int declaredMinSize, int declaredMaxSize, int elemType, boolean referenceTypes) { + public void declareTable(int index, int declaredMinSize, int declaredMaxSize, int elemType, byte[] initBytecode, Object initValue, boolean referenceTypes) { checkNotParsed(); - addTable(index, declaredMinSize, declaredMaxSize, elemType, referenceTypes); + addTable(index, declaredMinSize, declaredMaxSize, elemType, initValue, initBytecode, referenceTypes); module().addLinkAction((context, store, instance, imports) -> { final int maxAllowedSize = minUnsigned(declaredMaxSize, module().limits().tableInstanceSizeLimit()); module().limits().checkTableInstanceSize(declaredMinSize); @@ -1121,12 +1123,14 @@ public void allocateTable(int index, int declaredMinSize, int declaredMaxSize, i } final int address = store.tables().register(wasmTable); instance.setTableAddress(index, address); + + store.linker().resolveTableInitialization(instance, index, initBytecode, initValue); }); } void importTable(String moduleName, String tableName, int index, int initSize, int maxSize, int elemType, boolean referenceTypes) { checkNotParsed(); - addTable(index, initSize, maxSize, elemType, referenceTypes); + addTable(index, initSize, maxSize, elemType, null, null, referenceTypes); final ImportDescriptor importedTable = new ImportDescriptor(moduleName, tableName, ImportIdentifier.TABLE, index, numImportedSymbols()); importedTables.put(index, importedTable); importSymbol(importedTable); @@ -1136,13 +1140,13 @@ void importTable(String moduleName, String tableName, int index, int initSize, i }); } - void addTable(int index, int minSize, int maxSize, int elemType, boolean referenceTypes) { + void addTable(int index, int minSize, int maxSize, int elemType, Object initValue, byte[] initBytecode, boolean referenceTypes) { if (!referenceTypes) { assertTrue(importedTables.isEmpty(), "A table has already been imported in the module.", Failure.MULTIPLE_TABLES); assertTrue(tableCount == 0, "A table has already been declared in the module.", Failure.MULTIPLE_TABLES); } ensureTableCapacity(index); - final TableInfo table = new TableInfo(minSize, maxSize, elemType); + final TableInfo table = new TableInfo(minSize, maxSize, elemType, initValue, initBytecode); tables[index] = table; tableCount++; } @@ -1202,6 +1206,18 @@ public int tableElementType(int index) { return table.elemType; } + public Object tableInitialValue(int index) { + final TableInfo table = tables[index]; + assert table != null; + return table.initValue; + } + + public byte[] tableInitializerBytecode(int index) { + final TableInfo table = tables[index]; + assert table != null; + return table.initBytecode; + } + private void ensureMemoryCapacity(int index) { if (index >= memories.length) { final MemoryInfo[] nMemories = new MemoryInfo[Math.max(Integer.highestOneBit(index) << 1, 2 * memories.length)]; diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmInstantiator.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmInstantiator.java index 8f64f7732f77..459c1ef4d303 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmInstantiator.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmInstantiator.java @@ -154,6 +154,10 @@ static List recreateLinkActions(WasmModule module) { final WasmTable wasmTable = new WasmTable(tableMinSize, tableMaxSize, maxAllowedSize, tableElemType); final int address = store.tables().register(wasmTable); instance.setTableAddress(tableIndex, address); + + final byte[] initBytecode = module.tableInitializerBytecode(tableIndex); + final Object initValue = module.tableInitialValue(tableIndex); + store.linker().resolveTableInitialization(instance, tableIndex, initBytecode, initValue); }); } } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/BytecodeBitEncoding.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/BytecodeBitEncoding.java index 7260ed3cea3f..9d1a247c6061 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/BytecodeBitEncoding.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/BytecodeBitEncoding.java @@ -146,6 +146,10 @@ public class BytecodeBitEncoding { // Elem items + public static final int ELEM_ITEM_REF_NULL_ENTRY_PREFIX = 0; + public static final int ELEM_ITEM_REF_FUNC_ENTRY_PREFIX = 1; + public static final int ELEM_ITEM_GLOBAL_GET_ENTRY_PREFIX = 2; + public static final int ELEM_ITEM_TYPE_MASK = 0b1000_0000; public static final int ELEM_ITEM_TYPE_FUNCTION_INDEX = 0b0000_0000; public static final int ELEM_ITEM_TYPE_GLOBAL_INDEX = 0b1000_0000; diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/Failure.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/Failure.java index ee6ddbab11fc..bda6173dff62 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/Failure.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/Failure.java @@ -107,6 +107,7 @@ public enum Failure { UNKNOWN_REFERENCE(Type.INVALID, "unknown reference"), UNDECLARED_FUNCTION_REFERENCE(Type.INVALID, "undeclared function reference"), UNKNOWN_TAG(Type.INVALID, "unknown tag"), + UNINITIALIZED_TABLE(Type.INVALID, "uninitialized table of non-nullable element type"), // GraalWasm-specific: MODULE_SIZE_LIMIT_EXCEEDED(Type.INVALID, "module size exceeds limit"), diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/predefined/BuiltinModule.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/predefined/BuiltinModule.java index da2d53d956c0..e3a2196e537d 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/predefined/BuiltinModule.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/predefined/BuiltinModule.java @@ -121,9 +121,11 @@ protected int defineTable(WasmContext context, WasmModule module, String tableNa throw WasmException.create(Failure.MALFORMED_REFERENCE_TYPE, "Only reference types supported in tables."); } else if (!referenceTypes && type != WasmType.FUNCREF_TYPE) { throw WasmException.create(Failure.UNSPECIFIED_MALFORMED, "Only function types are currently supported in tables."); + } else if (!WasmType.isNullable(type)) { + throw WasmException.create(Failure.UNINITIALIZED_TABLE, "Tables of built-in modules must be nullable."); } int index = module.symbolTable().tableCount(); - module.symbolTable().allocateTable(index, initSize, maxSize, type, referenceTypes); + module.symbolTable().declareTable(index, initSize, maxSize, type, null, null, referenceTypes); module.symbolTable().exportTable(index, tableName); return index; } From 2df775ad0f16f199910d0404ae51a7c69f5b2e5b Mon Sep 17 00:00:00 2001 From: Jirka Marsik Date: Tue, 30 Sep 2025 13:48:47 +0200 Subject: [PATCH 07/40] Fixes for typed references spec tests --- .../src/org/graalvm/wasm/BinaryParser.java | 77 ++++++------------- .../org/graalvm/wasm/exception/Failure.java | 2 +- .../wasm/parser/validation/BlockFrame.java | 8 +- .../wasm/parser/validation/ParserState.java | 4 +- 4 files changed, 29 insertions(+), 62 deletions(-) diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java index 6c41a1944e28..a6e7e732fac1 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java @@ -120,7 +120,7 @@ public class BinaryParser extends BinaryStreamParser { private static final int MAGIC = 0x6d736100; private static final int VERSION = 0x00000001; - private static final int[] EMPTY_LOCALS = new int[0]; + private static final int[] EMPTY_TYPES = new int[0]; private final WasmModule module; private final WasmContext wasmContext; @@ -566,11 +566,7 @@ private CodeEntry readCodeEntry(int functionIndex, IntArrayList locals, int endO for (int index = 0; index != locals.size(); index++) { localTypes[index + paramCount] = locals.get(index); } - int[] resultTypes = new int[function.resultCount()]; - for (int index = 0; index != resultTypes.length; index++) { - resultTypes[index] = function.resultTypeAt(index); - } - return readFunction(functionIndex, localTypes, resultTypes, endOffset, hasNextFunction, bytecode, codeEntryIndex, null); + return readFunction(functionIndex, localTypes, endOffset, hasNextFunction, bytecode, codeEntryIndex, null); } private IntArrayList readCodeEntryLocals() { @@ -622,12 +618,14 @@ private static int[] encapsulateResultType(int type) { }; } - private CodeEntry readFunction(int functionIndex, int[] locals, int[] resultTypes, int sourceCodeEndOffset, boolean hasNextFunction, RuntimeBytecodeGen bytecode, + private CodeEntry readFunction(int functionIndex, int[] locals, int sourceCodeEndOffset, boolean hasNextFunction, RuntimeBytecodeGen bytecode, int codeEntryIndex, EconomicMap offsetToLineIndexMap) { final ParserState state = new ParserState(bytecode, module); final ArrayList callNodes = new ArrayList<>(); final int bytecodeStartOffset = bytecode.location(); - state.enterFunction(resultTypes, locals); + int[] paramTypes = module.function(functionIndex).paramTypes(); + int[] resultTypes = module.function(functionIndex).resultTypes(); + state.enterFunction(paramTypes, resultTypes, locals); int opcode; end: while (offset < sourceCodeEndOffset) { @@ -702,7 +700,6 @@ private CodeEntry readFunction(int functionIndex, int[] locals, int[] resultType break; } case Instructions.IF: { - state.popChecked(I32_TYPE); // condition final int[] ifParamTypes; final int[] ifResultTypes; readBlockType(multiResult, bulkMemoryAndRefTypes, simd, exceptions); @@ -724,6 +721,7 @@ private CodeEntry readFunction(int functionIndex, int[] locals, int[] resultType } default -> throw WasmException.create(Failure.DISABLED_MULTI_VALUE); } + state.popChecked(I32_TYPE); // condition state.popAll(ifParamTypes); state.enterIf(ifParamTypes, ifResultTypes); break; @@ -906,8 +904,8 @@ private CodeEntry readFunction(int functionIndex, int[] locals, int[] resultType checkExceptionHandlingSupport(opcode); final int tagIndex = readTagIndex(); final int typeIndex = module.tagTypeIndex(tagIndex); - final int[] paramTypes = module.functionTypeParamTypesAsArray(typeIndex); - state.popAll(paramTypes); + final int[] tagParamTypes = module.functionTypeParamTypesAsArray(typeIndex); + state.popAll(tagParamTypes); state.addMiscFlag(); state.addInstruction(Bytecode.THROW, tagIndex); @@ -2672,7 +2670,7 @@ private Pair readConstantExpression(int resultType) { final List stack = new ArrayList<>(); boolean calculable = true; - state.enterFunction(new int[]{resultType}, EMPTY_LOCALS); + state.enterFunction(EMPTY_TYPES, new int[]{resultType}, EMPTY_TYPES); int opcode; while ((opcode = read1() & 0xFF) != Instructions.END) { switch (opcode) { @@ -3206,41 +3204,24 @@ private void readFunctionType() { protected int readValueType(boolean allowRefTypes, boolean allowVecType, boolean allowExnType) { final int type = readSignedInt32(); - switch (type) { - case I32_TYPE, I64_TYPE, F32_TYPE, F64_TYPE -> { - return type; - } + return switch (type) { + case I32_TYPE, I64_TYPE, F32_TYPE, F64_TYPE -> type; case V128_TYPE -> { Assert.assertTrue(allowVecType, Failure.MALFORMED_VALUE_TYPE); - return type; + yield type; } case FUNCREF_TYPE, EXTERNREF_TYPE -> { Assert.assertTrue(allowRefTypes, Failure.MALFORMED_VALUE_TYPE); - return type; + yield type; } case EXNREF_TYPE -> { Assert.assertTrue(allowExnType, Failure.MALFORMED_VALUE_TYPE); - return type; - } - case REF_NULL_TYPE_HEADER, REF_TYPE_HEADER -> { - boolean nullable = type == REF_NULL_TYPE_HEADER; - int heapType = readSignedInt32(); - return switch (heapType) { - case FUNC_HEAPTYPE, EXTERN_HEAPTYPE -> WasmType.withNullable(nullable, heapType); - case EXN_HEAPTYPE -> { - Assert.assertTrue(allowExnType, Failure.MALFORMED_VALUE_TYPE); - yield WasmType.withNullable(nullable, heapType); - } - default -> { - if (heapType < 0 || heapType > WasmType.MAX_TYPE_INDEX) { - throw fail(Failure.MALFORMED_REFERENCE_TYPE, "Invalid heap type"); - } - yield WasmType.withNullable(nullable, heapType); - } - }; + yield type; } + case REF_NULL_TYPE_HEADER -> WasmType.withNullable(true, readHeapType(allowExnType)); + case REF_TYPE_HEADER -> WasmType.withNullable(false, readHeapType(allowExnType)); default -> throw Assert.fail(Failure.MALFORMED_VALUE_TYPE, "Invalid value type: 0x%02X", type); - } + }; } /** @@ -3278,20 +3259,8 @@ protected void readBlockType(int[] result, boolean allowRefTypes, boolean allowV } case REF_NULL_TYPE_HEADER, REF_TYPE_HEADER -> { boolean nullable = type == REF_NULL_TYPE_HEADER; - int heapType = readSignedInt32(); - result[0] = switch (heapType) { - case FUNC_HEAPTYPE, EXTERN_HEAPTYPE -> WasmType.withNullable(nullable, heapType); - case EXN_HEAPTYPE -> { - Assert.assertTrue(allowExnType, Failure.MALFORMED_VALUE_TYPE); - yield WasmType.withNullable(nullable, heapType); - } - default -> { - if (heapType < 0 || heapType > WasmType.MAX_TYPE_INDEX) { - throw fail(Failure.MALFORMED_REFERENCE_TYPE, "Invalid heap type"); - } - yield WasmType.withNullable(nullable, heapType); - } - }; + int heapType = readHeapType(allowExnType); + result[0] = WasmType.withNullable(nullable, heapType); result[1] = BLOCK_TYPE_VALTYPE; } default -> { @@ -3423,8 +3392,8 @@ private int readHeapType(boolean allowExnType) { yield heapType; } default -> { - if (heapType < 0 || heapType > WasmType.MAX_TYPE_INDEX) { - throw fail(Failure.MALFORMED_HEAP_TYPE, "Unexpected heap type"); + if (heapType < 0 || heapType >= module.typeCount()) { + throw fail(Failure.UNKNOWN_TYPE, "Unknown heap type %d", heapType); } yield heapType; } @@ -3639,7 +3608,7 @@ public Pair createFunctionDebugBytecode(int functionIndex, Ec final CodeEntry codeEntry = BytecodeParser.readCodeEntry(module, module.bytecode(), codeEntryIndex); offset = module.functionSourceCodeInstructionOffset(functionIndex); final int endOffset = module.functionSourceCodeEndOffset(functionIndex); - final CodeEntry result = readFunction(functionIndex, codeEntry.localTypes(), codeEntry.resultTypes(), endOffset, true, bytecode, codeEntryIndex, offsetToLineIndexMap); + final CodeEntry result = readFunction(functionIndex, codeEntry.localTypes(), endOffset, true, bytecode, codeEntryIndex, offsetToLineIndexMap); return Pair.create(result, bytecode.toArray()); } } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/Failure.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/Failure.java index bda6173dff62..a1e31b6f0477 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/Failure.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/Failure.java @@ -147,7 +147,7 @@ public enum Failure { INVALID_MULTI_VALUE_ARITY(Type.TRAP, "provided multi-value size does not match function type"), INVALID_TYPE_IN_MULTI_VALUE(Type.TRAP, "type of value in multi-value does not match the function type"), - NULL_REFERENCE(Type.TRAP, "defined element is ref.null"), + NULL_REFERENCE(Type.TRAP, "null reference"), NULL_FUNCTION_REFERENCE(Type.TRAP, "null function reference"), OUT_OF_BOUNDS_TABLE_ACCESS(Type.TRAP, "out of bounds table access"), // GraalWasm-specific: diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/BlockFrame.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/BlockFrame.java index 1da77891fd6d..4dfb545a06d8 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/BlockFrame.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/BlockFrame.java @@ -54,8 +54,6 @@ * Representation of a wasm block during module validation. */ class BlockFrame extends ControlFrame { - private static final int[] EMPTY_ARRAY = new int[0]; - private final IntArrayList branches; private final ArrayList exceptionHandlers; @@ -69,14 +67,14 @@ private BlockFrame(int[] paramTypes, int[] resultTypes, int initialStackSize, Bi this(paramTypes, resultTypes, initialStackSize, (BitSet) parentFrame.initializedLocals.clone()); } - static BlockFrame createFunctionFrame(int[] resultTypes, int[] locals) { + static BlockFrame createFunctionFrame(int[] paramTypes, int[] resultTypes, int[] locals) { BitSet initializedLocals = new BitSet(locals.length); for (int localIndex = 0; localIndex < locals.length; localIndex++) { - if (WasmType.hasDefaultValue(locals[localIndex])) { + if (localIndex < paramTypes.length || WasmType.hasDefaultValue(locals[localIndex])) { initializedLocals.set(localIndex); } } - return new BlockFrame(EMPTY_ARRAY, resultTypes, 0, initializedLocals); + return new BlockFrame(paramTypes, resultTypes, 0, initializedLocals); } @Override diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ParserState.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ParserState.java index cc34209feb4d..043c8dd4a521 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ParserState.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ParserState.java @@ -265,8 +265,8 @@ private void unwindStack(int size) { } } - public void enterFunction(int[] resultTypes, int[] locals) { - ControlFrame frame = BlockFrame.createFunctionFrame(resultTypes, locals); + public void enterFunction(int[] paramTypes, int[] resultTypes, int[] locals) { + ControlFrame frame = BlockFrame.createFunctionFrame(paramTypes, resultTypes, locals); controlStack.push(frame); } From 48ad15bbabcd09a49c2370f28de4afcfbf95c9bf Mon Sep 17 00:00:00 2001 From: Jirka Marsik Date: Mon, 6 Oct 2025 17:37:35 +0200 Subject: [PATCH 08/40] Fix br_table type checking in the presence of subtyping --- .../wasm/parser/validation/ParserState.java | 24 +++++++------------ 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ParserState.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ParserState.java index 043c8dd4a521..cda383b27bcc 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ParserState.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ParserState.java @@ -458,12 +458,19 @@ public void addBranchTable(int[] branchLabels) { checkLabelExists(branchLabel); ControlFrame frame = getFrame(branchLabel); int[] branchLabelReturnTypes = frame.labelTypes(); + int arity = branchLabelReturnTypes.length; for (int otherBranchLabel : branchLabels) { checkLabelExists(otherBranchLabel); frame = getFrame(otherBranchLabel); int[] otherBranchLabelReturnTypes = frame.labelTypes(); - checkLabelTypes(branchLabelReturnTypes, otherBranchLabelReturnTypes); - pushAll(popAll(otherBranchLabelReturnTypes)); + if (otherBranchLabelReturnTypes.length != arity) { + throw ValidationErrors.createLabelTypesMismatch(branchLabelReturnTypes, otherBranchLabelReturnTypes); + } + try { + pushAll(popAll(otherBranchLabelReturnTypes)); + } catch (WasmException e) { + throw ValidationErrors.createLabelTypesMismatch(branchLabelReturnTypes, otherBranchLabelReturnTypes); + } frame.addBranchTableItem(bytecode); } popAll(branchLabelReturnTypes); @@ -782,19 +789,6 @@ public void checkLabelExists(int label) { } } - /** - * Checks if the value types of two different labels match. - * - * @param expectedTypes The expected value types. - * @param actualTypes The value types that should be checked. - * @throws WasmException If the provided sets of value types do not match. - */ - public void checkLabelTypes(int[] expectedTypes, int[] actualTypes) { - if (isTypeMismatch(expectedTypes, actualTypes)) { - throw ValidationErrors.createLabelTypesMismatch(expectedTypes, actualTypes); - } - } - /** * Sets the current control stack unreachable. */ From 0c2224338d808a1c8700c6d4cc358bc86de3be24 Mon Sep 17 00:00:00 2001 From: Jirka Marsik Date: Tue, 7 Oct 2025 00:43:43 +0200 Subject: [PATCH 09/40] Added a context option for typed references, updated CHANGELOG Also fixed up code style in typed function references implementation. --- wasm/CHANGELOG.md | 1 + .../ReferenceTypesValidationSuite.java | 3 +- .../src/org/graalvm/wasm/BinaryParser.java | 92 ++++++++++++------- .../src/org/graalvm/wasm/SymbolTable.java | 14 ++- .../org/graalvm/wasm/WasmContextOptions.java | 10 ++ .../org/graalvm/wasm/WasmInstantiator.java | 2 +- .../src/org/graalvm/wasm/WasmOptions.java | 3 + .../src/org/graalvm/wasm/WasmType.java | 3 +- .../graalvm/wasm/nodes/WasmFunctionNode.java | 6 +- .../parser/bytecode/RuntimeBytecodeGen.java | 4 +- 10 files changed, 91 insertions(+), 47 deletions(-) diff --git a/wasm/CHANGELOG.md b/wasm/CHANGELOG.md index ae0151d1fb8e..1a38a9544fb3 100644 --- a/wasm/CHANGELOG.md +++ b/wasm/CHANGELOG.md @@ -5,6 +5,7 @@ This changelog summarizes major changes to the WebAssembly engine implemented in ## Version 25.1.0 * Implemented the [exception handling](https://github.com/WebAssembly/exception-handling) proposal. This feature can be enabled with the experimental option `wasm.Exceptions=true`. +* Implemented the [typed function references](https://github.com/WebAssembly/function-references) proposal. This feature can be enabled with the experimental option `wasm.TypedFunctionReferences=true`. ## Version 25.0.0 diff --git a/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/validation/ReferenceTypesValidationSuite.java b/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/validation/ReferenceTypesValidationSuite.java index 8488708c5abb..8307423ee8fc 100644 --- a/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/validation/ReferenceTypesValidationSuite.java +++ b/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/validation/ReferenceTypesValidationSuite.java @@ -535,8 +535,7 @@ public void testMultipleTables() throws IOException { // (table 1 1 funcref) // (table 1 1 externref) // (table 1 1 exnref) - final byte[] binary = newBuilder().addTable(1, 1, WasmType.FUNCREF_TYPE).addTable(1, 1, WasmType.EXTERNREF_TYPE) - .addTable(1, 1, WasmType.EXNREF_TYPE).build(); + final byte[] binary = newBuilder().addTable(1, 1, WasmType.FUNCREF_TYPE).addTable(1, 1, WasmType.EXTERNREF_TYPE).addTable(1, 1, WasmType.EXNREF_TYPE).build(); runParserTest(binary, options -> options.option("wasm.Exceptions", "true"), Context::eval); runParserTest(binary, Context::eval); } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java index a6e7e732fac1..1559b020153f 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java @@ -135,6 +135,7 @@ public class BinaryParser extends BinaryStreamParser { private final boolean threads; private final boolean simd; private final boolean exceptions; + private final boolean typedFunctionReferences; @TruffleBoundary public BinaryParser(WasmModule module, WasmContext context, byte[] data) { @@ -151,6 +152,7 @@ public BinaryParser(WasmModule module, WasmContext context, byte[] data) { this.threads = context.getContextOptions().supportThreads(); this.simd = context.getContextOptions().supportSIMD(); this.exceptions = context.getContextOptions().supportExceptions(); + this.typedFunctionReferences = context.getContextOptions().supportTypedFunctionReferences(); } @TruffleBoundary @@ -437,7 +439,7 @@ private void readImportSection() { break; } case ImportIdentifier.TABLE: { - final int elemType = readRefType(exceptions); + final int elemType = readRefType(); if (!bulkMemoryAndRefTypes) { assertIntEqual(elemType, FUNCREF_TYPE, Failure.UNSPECIFIED_MALFORMED, "Invalid element type for table import"); } @@ -455,7 +457,7 @@ private void readImportSection() { break; } case ImportIdentifier.GLOBAL: { - int type = readValueType(bulkMemoryAndRefTypes, simd, exceptions); + int type = readValueType(); byte mutability = readMutability(); int globalIndex = module.symbolTable().numGlobals(); module.symbolTable().importGlobal(moduleName, memberName, globalIndex, type, mutability); @@ -498,15 +500,16 @@ private void readTableSection() { final Object initValue; final byte[] initBytecode; if (peek1(data, offset) == 0x40 && peek1(data, offset + 1) == 0x00) { + Assert.assertTrue(typedFunctionReferences, Failure.MALFORMED_VALUE_TYPE); offset += 2; - elemType = readRefType(exceptions); + elemType = readRefType(); readTableLimits(multiResult); Pair initExpression = readConstantExpression(elemType); initValue = initExpression.getLeft(); // Drop the initializer bytecode if we can eval the initializer during parsing initBytecode = initValue == null ? initExpression.getRight() : null; } else { - elemType = readRefType(exceptions); + elemType = readRefType(); readTableLimits(multiResult); initValue = null; initBytecode = null; @@ -578,7 +581,7 @@ private IntArrayList readCodeEntryLocals() { final int groupLength = readUnsignedInt32(); localsLength += groupLength; module.limits().checkLocalCount(localsLength); - final int t = readValueType(bulkMemoryAndRefTypes, simd, exceptions); + final int t = readValueType(); for (int i = 0; i != groupLength; ++i) { localTypes.add(t); } @@ -649,7 +652,7 @@ private CodeEntry readFunction(int functionIndex, int[] locals, int sourceCodeEn case Instructions.BLOCK: { final int[] blockParamTypes; final int[] blockResultTypes; - readBlockType(multiResult, bulkMemoryAndRefTypes, simd, exceptions); + readBlockType(multiResult); // Extract value based on result arity. switch (multiResult[1]) { case BLOCK_TYPE_VOID -> { @@ -676,7 +679,7 @@ private CodeEntry readFunction(int functionIndex, int[] locals, int sourceCodeEn // Jumps are targeting the loop instruction for OSR. final int[] loopParamTypes; final int[] loopResultTypes; - readBlockType(multiResult, bulkMemoryAndRefTypes, simd, exceptions); + readBlockType(multiResult); // Extract value based on result arity. switch (multiResult[1]) { case BLOCK_TYPE_VOID -> { @@ -702,7 +705,7 @@ private CodeEntry readFunction(int functionIndex, int[] locals, int sourceCodeEn case Instructions.IF: { final int[] ifParamTypes; final int[] ifResultTypes; - readBlockType(multiResult, bulkMemoryAndRefTypes, simd, exceptions); + readBlockType(multiResult); // Extract value based on result arity. switch (multiResult[1]) { case BLOCK_TYPE_VOID -> { @@ -860,7 +863,7 @@ private CodeEntry readFunction(int functionIndex, int[] locals, int sourceCodeEn checkBulkMemoryAndRefTypesSupport(opcode); final int length = readLength(); assertIntEqual(length, 1, Failure.INVALID_RESULT_ARITY); - final int t = readValueType(bulkMemoryAndRefTypes, simd, exceptions); + final int t = readValueType(); state.popChecked(I32_TYPE); state.popChecked(t); state.popChecked(t); @@ -876,7 +879,7 @@ private CodeEntry readFunction(int functionIndex, int[] locals, int sourceCodeEn checkExceptionHandlingSupport(opcode); final int[] tryTableParamTypes; final int[] tryTableResultTypes; - readBlockType(multiResult, bulkMemoryAndRefTypes, simd, exceptions); + readBlockType(multiResult); // Extract value based on result arity. switch (multiResult[1]) { case BLOCK_TYPE_VOID -> { @@ -1127,6 +1130,7 @@ private CodeEntry readFunction(int functionIndex, int[] locals, int sourceCodeEn break; } case Instructions.CALL_REF: { + checkTypedFunctionReferencesSupport(opcode); final int expectedFunctionTypeIndex = readTypeIndex(); final int functionReferenceType = WasmType.withNullable(true, expectedFunctionTypeIndex); state.popChecked(functionReferenceType); @@ -1148,6 +1152,7 @@ private CodeEntry readFunction(int functionIndex, int[] locals, int sourceCodeEn break; } case Instructions.REF_AS_NON_NULL: { + checkTypedFunctionReferencesSupport(opcode); final int referenceType = state.popReferenceTypeChecked(); final int nonNullReferenceType = WasmType.withNullable(false, referenceType); state.push(nonNullReferenceType); @@ -1156,6 +1161,7 @@ private CodeEntry readFunction(int functionIndex, int[] locals, int sourceCodeEn break; } case Instructions.BR_ON_NULL: { + checkTypedFunctionReferencesSupport(opcode); final int branchLabel = readTargetOffset(); final int referenceType = state.popReferenceTypeChecked(); state.addBranchOnNull(branchLabel); @@ -1164,6 +1170,7 @@ private CodeEntry readFunction(int functionIndex, int[] locals, int sourceCodeEn break; } case Instructions.BR_ON_NON_NULL: { + checkTypedFunctionReferencesSupport(opcode); final int branchLabel = readTargetOffset(); final int referenceType = state.popReferenceTypeChecked(); final int nonNullReferenceType = WasmType.withNullable(false, referenceType); @@ -1705,7 +1712,7 @@ private void readNumericInstructions(ParserState state, int opcode) { break; case Instructions.REF_NULL: checkBulkMemoryAndRefTypesSupport(opcode); - final int heapType = readHeapType(exceptions); + final int heapType = readHeapType(); final int nullableReferenceType = WasmType.withNullable(true, heapType); state.push(nullableReferenceType); state.addInstruction(Bytecode.REF_NULL); @@ -2480,6 +2487,10 @@ private void checkExceptionHandlingSupport(int opcode) { checkContextOption(wasmContext.getContextOptions().supportExceptions(), "Exception handling is not enabled (opcode: 0x%02x)", opcode); } + private void checkTypedFunctionReferencesSupport(int opcode) { + checkContextOption(wasmContext.getContextOptions().supportTypedFunctionReferences(), "Typed function references are not enabled (opcode: 0x%02x)", opcode); + } + private void store(ParserState state, int type, int n, long[] result) { int alignHint = readAlignHint(n); final int memoryIndex = readMemoryIndexFromAlignHint(alignHint); @@ -2714,7 +2725,7 @@ private Pair readConstantExpression(int resultType) { } case Instructions.REF_NULL: checkBulkMemoryAndRefTypesSupport(opcode); - final int heapType = readHeapType(exceptions); + final int heapType = readHeapType(); final int nullableReferenceType = WasmType.withNullable(true, heapType); state.push(nullableReferenceType); state.addInstruction(Bytecode.REF_NULL); @@ -2859,7 +2870,7 @@ private long[] readElemExpressions(int elemType) { throw WasmException.format(Failure.ILLEGAL_OPCODE, "Illegal opcode for constant expression: 0x%02X", opcode); } case Instructions.REF_NULL: - final int heapType = readHeapType(exceptions); + final int heapType = readHeapType(); final int nullableReferenceType = WasmType.withNullable(true, heapType); Assert.assertTrue(module.matches(elemType, nullableReferenceType), "Invalid ref.null type: 0x%02X", Failure.TYPE_MISMATCH); elements[index] = ((long) ELEM_ITEM_REF_NULL_ENTRY_PREFIX << 32); @@ -2929,7 +2940,7 @@ private void readElementSection(RuntimeBytecodeGen bytecode) { } if (useExpressions) { if (useType) { - elemType = readRefType(exceptions); + elemType = readRefType(); } else { elemType = FUNCREF_TYPE; } @@ -3065,7 +3076,7 @@ private void readGlobalSection() { final int startingGlobalIndex = module.symbolTable().numGlobals(); for (int globalIndex = startingGlobalIndex; globalIndex != startingGlobalIndex + globalCount; globalIndex++) { assertTrue(!isEOF(), Failure.LENGTH_OUT_OF_BOUNDS); - final int type = readValueType(bulkMemoryAndRefTypes, simd, exceptions); + final int type = readValueType(); // 0x00 means const, 0x01 means var final byte mutability = readMutability(); // Global initialization expressions must be constant expressions: @@ -3183,14 +3194,14 @@ private void readFunctionType() { module.limits().checkParamCount(paramCount); int[] paramTypes = new int[paramCount]; for (int paramIdx = 0; paramIdx < paramCount; paramIdx++) { - paramTypes[paramIdx] = readValueType(bulkMemoryAndRefTypes, simd, exceptions); + paramTypes[paramIdx] = readValueType(); } int resultCount = readLength(); module.limits().checkResultCount(resultCount, multiValue); int[] resultTypes = new int[resultCount]; for (int resultIdx = 0; resultIdx < resultCount; resultIdx++) { - resultTypes[resultIdx] = readValueType(bulkMemoryAndRefTypes, simd, exceptions); + resultTypes[resultIdx] = readValueType(); } int funcTypeIdx = module.symbolTable().allocateFunctionType(paramCount, resultCount, multiValue); @@ -3202,24 +3213,30 @@ private void readFunctionType() { } } - protected int readValueType(boolean allowRefTypes, boolean allowVecType, boolean allowExnType) { + protected int readValueType() { final int type = readSignedInt32(); return switch (type) { case I32_TYPE, I64_TYPE, F32_TYPE, F64_TYPE -> type; case V128_TYPE -> { - Assert.assertTrue(allowVecType, Failure.MALFORMED_VALUE_TYPE); + Assert.assertTrue(simd, Failure.MALFORMED_VALUE_TYPE); yield type; } case FUNCREF_TYPE, EXTERNREF_TYPE -> { - Assert.assertTrue(allowRefTypes, Failure.MALFORMED_VALUE_TYPE); + Assert.assertTrue(bulkMemoryAndRefTypes, Failure.MALFORMED_VALUE_TYPE); yield type; } case EXNREF_TYPE -> { - Assert.assertTrue(allowExnType, Failure.MALFORMED_VALUE_TYPE); + Assert.assertTrue(exceptions, Failure.MALFORMED_VALUE_TYPE); yield type; } - case REF_NULL_TYPE_HEADER -> WasmType.withNullable(true, readHeapType(allowExnType)); - case REF_TYPE_HEADER -> WasmType.withNullable(false, readHeapType(allowExnType)); + case REF_NULL_TYPE_HEADER -> { + Assert.assertTrue(typedFunctionReferences, Failure.MALFORMED_VALUE_TYPE); + yield WasmType.withNullable(true, readHeapType()); + } + case REF_TYPE_HEADER -> { + Assert.assertTrue(typedFunctionReferences, Failure.MALFORMED_VALUE_TYPE); + yield WasmType.withNullable(false, readHeapType()); + } default -> throw Assert.fail(Failure.MALFORMED_VALUE_TYPE, "Invalid value type: 0x%02X", type); }; } @@ -3232,7 +3249,7 @@ protected int readValueType(boolean allowRefTypes, boolean allowVecType, boolean * @param result The array used for returning the result. * */ - protected void readBlockType(int[] result, boolean allowRefTypes, boolean allowVecType, boolean allowExnType) { + protected void readBlockType(int[] result) { int type = readSignedInt32(); switch (type) { case VOID_BLOCK_TYPE -> { @@ -3243,23 +3260,23 @@ protected void readBlockType(int[] result, boolean allowRefTypes, boolean allowV result[1] = BLOCK_TYPE_VALTYPE; } case V128_TYPE -> { - Assert.assertTrue(allowVecType, Failure.MALFORMED_VALUE_TYPE); + Assert.assertTrue(simd, Failure.MALFORMED_VALUE_TYPE); result[0] = type; result[1] = BLOCK_TYPE_VALTYPE; } case FUNCREF_TYPE, EXTERNREF_TYPE -> { - Assert.assertTrue(allowRefTypes, Failure.MALFORMED_VALUE_TYPE); + Assert.assertTrue(bulkMemoryAndRefTypes, Failure.MALFORMED_VALUE_TYPE); result[0] = type; result[1] = BLOCK_TYPE_VALTYPE; } case EXNREF_TYPE -> { - Assert.assertTrue(allowExnType, Failure.MALFORMED_VALUE_TYPE); + Assert.assertTrue(exceptions, Failure.MALFORMED_VALUE_TYPE); result[0] = type; result[1] = BLOCK_TYPE_VALTYPE; } case REF_NULL_TYPE_HEADER, REF_TYPE_HEADER -> { boolean nullable = type == REF_NULL_TYPE_HEADER; - int heapType = readHeapType(allowExnType); + int heapType = readHeapType(); result[0] = WasmType.withNullable(nullable, heapType); result[1] = BLOCK_TYPE_VALTYPE; } @@ -3369,32 +3386,39 @@ private byte readImportType() { return read1(); } - private int readRefType(boolean allowExnType) { + private int readRefType() { final int refType = readSignedInt32(); return switch (refType) { case FUNCREF_TYPE, EXTERNREF_TYPE -> refType; case EXNREF_TYPE -> { - assertTrue(allowExnType, Failure.MALFORMED_REFERENCE_TYPE); + assertTrue(exceptions, Failure.MALFORMED_REFERENCE_TYPE); yield refType; } - case REF_NULL_TYPE_HEADER -> WasmType.withNullable(true, readHeapType(allowExnType)); - case REF_TYPE_HEADER -> WasmType.withNullable(false, readHeapType(allowExnType)); + case REF_NULL_TYPE_HEADER -> { + assertTrue(typedFunctionReferences, Failure.MALFORMED_REFERENCE_TYPE); + yield WasmType.withNullable(true, readHeapType()); + } + case REF_TYPE_HEADER -> { + assertTrue(typedFunctionReferences, Failure.MALFORMED_REFERENCE_TYPE); + yield WasmType.withNullable(false, readHeapType()); + } default -> throw fail(Failure.MALFORMED_REFERENCE_TYPE, "Unexpected reference type"); }; } - private int readHeapType(boolean allowExnType) { + private int readHeapType() { int heapType = readSignedInt32(); return switch (heapType) { case FUNC_HEAPTYPE, EXTERN_HEAPTYPE -> heapType; case EXN_HEAPTYPE -> { - assertTrue(allowExnType, Failure.MALFORMED_HEAP_TYPE); + assertTrue(exceptions, Failure.MALFORMED_HEAP_TYPE); yield heapType; } default -> { if (heapType < 0 || heapType >= module.typeCount()) { throw fail(Failure.UNKNOWN_TYPE, "Unknown heap type %d", heapType); } + assertTrue(typedFunctionReferences, Failure.MALFORMED_HEAP_TYPE); yield heapType; } }; diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java index 57ccc2c06636..7de8d1538cf2 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java @@ -186,7 +186,8 @@ public ClosedHeapType heapType() { @Override public boolean matches(ClosedValueType valueSubType) { - return valueSubType == BottomType.BOTTOM || valueSubType instanceof ClosedReferenceType referenceSubType && (!referenceSubType.nullable || this.nullable) && this.closedHeapType.matches(referenceSubType.closedHeapType); + return valueSubType == BottomType.BOTTOM || valueSubType instanceof ClosedReferenceType referenceSubType && (!referenceSubType.nullable || this.nullable) && + this.closedHeapType.matches(referenceSubType.closedHeapType); } @Override @@ -276,7 +277,8 @@ public Kind kind() { public static final class BottomType extends ClosedValueType { public static final BottomType BOTTOM = new BottomType(); - private BottomType() {} + private BottomType() { + } @Override public boolean matches(ClosedValueType valueSubType) { @@ -296,8 +298,10 @@ public Kind kind() { * Note: this is the upper bound defined by the module. A table instance * might have a lower internal max allowed size in practice. * @param elemType The element type of the table. - * @param initValue The initial value of the table's elements, can be {@code null} if no initializer present - * @param initBytecode The bytecode of the table's initializer expression, can be {@code null} if no initializer present + * @param initValue The initial value of the table's elements, can be {@code null} if no + * initializer present + * @param initBytecode The bytecode of the table's initializer expression, can be {@code null} + * if no initializer present */ public record TableInfo(int initialSize, int maximumSize, int elemType, Object initValue, byte[] initBytecode) { } @@ -604,7 +608,7 @@ private void checkNotParsed() { private void checkUniqueExport(String name) { CompilerAsserts.neverPartOfCompilation(); - if (exportedFunctions.containsKey(name) || exportedGlobals.containsKey(name) || exportedMemories.containsKey(name) || exportedTables.containsKey(name)) { + if (exportedFunctions.containsKey(name) || exportedGlobals.containsKey(name) || exportedMemories.containsKey(name) || exportedTables.containsKey(name) || exportedTags.containsKey(name)) { throw WasmException.create(Failure.DUPLICATE_EXPORT, "All export names must be different, but '" + name + "' is exported twice."); } } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmContextOptions.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmContextOptions.java index 70667ba2e074..085bd6f34cb1 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmContextOptions.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmContextOptions.java @@ -61,6 +61,7 @@ public final class WasmContextOptions { @CompilationFinal private boolean simd; @CompilationFinal private boolean relaxedSimd; @CompilationFinal private boolean exceptions; + @CompilationFinal private boolean typedFunctionReferences; @CompilationFinal private boolean memoryOverheadMode; @CompilationFinal private boolean constantRandomGet; @@ -93,6 +94,7 @@ private void setOptionValues() { this.simd = readBooleanOption(WasmOptions.SIMD); this.relaxedSimd = readBooleanOption(WasmOptions.RelaxedSIMD); this.exceptions = readBooleanOption(WasmOptions.Exceptions); + this.typedFunctionReferences = readBooleanOption(WasmOptions.TypedFunctionReferences); this.memoryOverheadMode = readBooleanOption(WasmOptions.MemoryOverheadMode); this.constantRandomGet = readBooleanOption(WasmOptions.WasiConstantRandomGet); this.directByteBufferMemoryAccess = readBooleanOption(WasmOptions.DirectByteBufferMemoryAccess); @@ -165,6 +167,10 @@ public boolean supportExceptions() { return exceptions; } + public boolean supportTypedFunctionReferences() { + return typedFunctionReferences; + } + public boolean memoryOverheadMode() { return memoryOverheadMode; } @@ -199,6 +205,7 @@ public int hashCode() { hash = 53 * hash + (this.simd ? 1 : 0); hash = 53 * hash + (this.relaxedSimd ? 1 : 0); hash = 53 * hash + (this.exceptions ? 1 : 0); + hash = 53 * hash + (this.typedFunctionReferences ? 1 : 0); hash = 53 * hash + (this.memoryOverheadMode ? 1 : 0); hash = 53 * hash + (this.constantRandomGet ? 1 : 0); hash = 53 * hash + (this.directByteBufferMemoryAccess ? 1 : 0); @@ -251,6 +258,9 @@ public boolean equals(Object obj) { if (this.exceptions != other.exceptions) { return false; } + if (this.typedFunctionReferences != other.typedFunctionReferences) { + return false; + } if (this.memoryOverheadMode != other.memoryOverheadMode) { return false; } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmInstantiator.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmInstantiator.java index 459c1ef4d303..50d3c4ac9e5a 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmInstantiator.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmInstantiator.java @@ -156,7 +156,7 @@ static List recreateLinkActions(WasmModule module) { instance.setTableAddress(tableIndex, address); final byte[] initBytecode = module.tableInitializerBytecode(tableIndex); - final Object initValue = module.tableInitialValue(tableIndex); + final Object initValue = module.tableInitialValue(tableIndex); store.linker().resolveTableInitialization(instance, tableIndex, initBytecode, initValue); }); } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmOptions.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmOptions.java index 3b8100ebea56..54dc2e7f2538 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmOptions.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmOptions.java @@ -139,6 +139,9 @@ public enum ConstantsStorePolicy { @Option(help = "Enable support for exception handling", category = OptionCategory.EXPERT, stability = OptionStability.EXPERIMENTAL, usageSyntax = "false|true") // public static final OptionKey Exceptions = new OptionKey<>(false); + @Option(help = "Enable support for typed function references", category = OptionCategory.EXPERT, stability = OptionStability.EXPERIMENTAL, usageSyntax = "false|true") // + public static final OptionKey TypedFunctionReferences = new OptionKey<>(false); + @Option(help = "In this mode memories and tables are not initialized.", category = OptionCategory.INTERNAL, stability = OptionStability.EXPERIMENTAL, usageSyntax = "false|true") // public static final OptionKey MemoryOverheadMode = new OptionKey<>(false); diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmType.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmType.java index ad0bcc9f204a..a4b350b2b539 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmType.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmType.java @@ -163,7 +163,8 @@ public static boolean isVectorType(int type) { } public static boolean isReferenceType(int type) { - return isConcreteReferenceType(type) || withNullable(true, type) == FUNC_HEAPTYPE || withNullable(true, type) == EXTERN_HEAPTYPE || withNullable(true, type) == EXN_HEAPTYPE || isBottomType(type); + return isConcreteReferenceType(type) || withNullable(true, type) == FUNC_HEAPTYPE || withNullable(true, type) == EXTERN_HEAPTYPE || withNullable(true, type) == EXN_HEAPTYPE || + isBottomType(type); } public static boolean isBottomType(int type) { diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java index 22ad8abdebfc..4ff55469a263 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java @@ -1737,7 +1737,8 @@ public Object executeBodyFromOffset(WasmInstance instance, VirtualFrame frame, i Object reference = popReference(frame, --stackPointer); if (profileCondition(bytecode, offset + 1, reference == WasmConstant.NULL)) { final int offsetDelta = rawPeekU8(bytecode, offset); - // BR_ON_NULL_U8 encodes the back jump value as a positive byte value. BR_ON_NULL_U8 + // BR_ON_NULL_U8 encodes the back jump value as a positive byte + // value. BR_ON_NULL_U8 // can never perform a forward jump. offset -= offsetDelta; } else { @@ -1761,7 +1762,8 @@ public Object executeBodyFromOffset(WasmInstance instance, VirtualFrame frame, i Object reference = popReference(frame, --stackPointer); if (profileCondition(bytecode, offset + 1, reference != WasmConstant.NULL)) { final int offsetDelta = rawPeekU8(bytecode, offset); - // BR_ON_NULL_U8 encodes the back jump value as a positive byte value. BR_ON_NULL_U8 + // BR_ON_NULL_U8 encodes the back jump value as a positive byte + // value. BR_ON_NULL_U8 // can never perform a forward jump. offset -= offsetDelta; pushReference(frame, stackPointer++, reference); diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/RuntimeBytecodeGen.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/RuntimeBytecodeGen.java index e63b7e6e9845..a4d3eb8441c2 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/RuntimeBytecodeGen.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/RuntimeBytecodeGen.java @@ -623,8 +623,8 @@ public void addIndirectCall(int nodeIndex, int typeIndex, int tableIndex) { /** * Adds a reference call instruction to the bytecode. If the nodeIndex and typeIndex both fit - * into a u8 value, a call_ref_u8 and two u8 values are added. Otherwise, a call_ref_i32 and - * two i32 values are added. + * into a u8 value, a call_ref_u8 and two u8 values are added. Otherwise, a call_ref_i32 and two + * i32 values are added. * * @param nodeIndex The node index of the reference call * @param typeIndex The type index of the reference call From 16d9a906bad94ecd50c189bd224d02586ebe0b63 Mon Sep 17 00:00:00 2001 From: Jirka Marsik Date: Tue, 7 Oct 2025 01:19:26 +0200 Subject: [PATCH 10/40] Avoid GraalWasm code duplcation when emitting branches --- .../test/suites/bytecode/BytecodeSuite.java | 20 +-- .../parser/bytecode/RuntimeBytecodeGen.java | 157 ++++++------------ .../wasm/parser/validation/BlockFrame.java | 19 +-- .../wasm/parser/validation/ControlFrame.java | 19 +-- .../wasm/parser/validation/IfFrame.java | 21 +-- .../wasm/parser/validation/LoopFrame.java | 19 +-- .../wasm/parser/validation/ParserState.java | 8 +- 7 files changed, 74 insertions(+), 189 deletions(-) diff --git a/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/bytecode/BytecodeSuite.java b/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/bytecode/BytecodeSuite.java index 6a062a4ca2af..fab3b7403549 100644 --- a/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/bytecode/BytecodeSuite.java +++ b/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/bytecode/BytecodeSuite.java @@ -159,7 +159,7 @@ public void testInvalidResultType() { @Test public void testBrU8Min() { - test(b -> b.addBranch(1), new byte[]{Bytecode.BR_U8, 0x00}); + test(b -> b.addBranch(1, RuntimeBytecodeGen.BranchOp.BR), new byte[]{Bytecode.BR_U8, 0x00}); } @Test @@ -171,18 +171,18 @@ public void testBrU8Max() { for (int i = 0; i < 254; i++) { b.addOp(0); } - b.addBranch(0); + b.addBranch(0, RuntimeBytecodeGen.BranchOp.BR); }, expected); } @Test public void testBrI32MinForward() { - test(b -> b.addBranch(2), new byte[]{Bytecode.BR_I32, 0x01, 0x00, 0x00, 0x00}); + test(b -> b.addBranch(2, RuntimeBytecodeGen.BranchOp.BR), new byte[]{Bytecode.BR_I32, 0x01, 0x00, 0x00, 0x00}); } @Test public void testBrI32MaxForward() { - test(b -> b.addBranch(2147483647), new byte[]{Bytecode.BR_I32, (byte) 0xFE, (byte) 0xFF, (byte) 0xFF, 0x7F}); + test(b -> b.addBranch(2147483647, RuntimeBytecodeGen.BranchOp.BR), new byte[]{Bytecode.BR_I32, (byte) 0xFE, (byte) 0xFF, (byte) 0xFF, 0x7F}); } @Test @@ -197,13 +197,13 @@ public void testBrI32MinBackward() { for (int i = 0; i < 255; i++) { b.addOp(0); } - b.addBranch(0); + b.addBranch(0, RuntimeBytecodeGen.BranchOp.BR); }, expected); } @Test public void testBrIfU8Min() { - test(b -> b.addBranchIf(1), new byte[]{Bytecode.BR_IF_U8, 0x00, 0x00, 0x00}); + test(b -> b.addBranch(1, RuntimeBytecodeGen.BranchOp.BR_IF), new byte[]{Bytecode.BR_IF_U8, 0x00, 0x00, 0x00}); } @Test @@ -215,18 +215,18 @@ public void testBrIfU8Max() { for (int i = 0; i < 254; i++) { b.addOp(0); } - b.addBranchIf(0); + b.addBranch(0, RuntimeBytecodeGen.BranchOp.BR_IF); }, expected); } @Test public void testBrIfI32MinForward() { - test(b -> b.addBranchIf(2), new byte[]{Bytecode.BR_IF_I32, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00}); + test(b -> b.addBranch(2, RuntimeBytecodeGen.BranchOp.BR_IF), new byte[]{Bytecode.BR_IF_I32, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00}); } @Test public void testBrIfI32MaxForward() { - test(b -> b.addBranchIf(2147483647), new byte[]{Bytecode.BR_IF_I32, (byte) 0xFE, (byte) 0xFF, (byte) 0xFF, 0x7F, 0x00, 0x00}); + test(b -> b.addBranch(2147483647, RuntimeBytecodeGen.BranchOp.BR_IF), new byte[]{Bytecode.BR_IF_I32, (byte) 0xFE, (byte) 0xFF, (byte) 0xFF, 0x7F, 0x00, 0x00}); } @Test @@ -241,7 +241,7 @@ public void testBrIfI32MinBackward() { for (int i = 0; i < 255; i++) { b.addOp(0); } - b.addBranchIf(0); + b.addBranch(0, RuntimeBytecodeGen.BranchOp.BR_IF); }, expected); } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/RuntimeBytecodeGen.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/RuntimeBytecodeGen.java index a4d3eb8441c2..b5212dc19a57 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/RuntimeBytecodeGen.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/RuntimeBytecodeGen.java @@ -356,7 +356,7 @@ public int addLoopLabel(int resultCount, int stackSize, int commonResultType) { /** * Adds an if opcode to the bytecode and reserves an i32 value for the jump offset and a 2-byte * profile. - * + * * @return The location of the jump offset to be patched later. (see * {@link #patchLocation(int, int)}. */ @@ -370,139 +370,82 @@ public int addIfLocation() { return location; } - /** - * Adds a branch opcode to the bytecode. If the negative jump offset fits into a u8 value, a - * br_u8 and u8 jump offset is added (The jump offset is encoded as a positive value). - * Otherwise, a br_i32 and i32 jump offset is added. - * - * @param location The target location of the branch. - */ - public void addBranch(int location) { - assert location >= 0; - final int relativeOffset = location - (location() + 1); - if (relativeOffset <= 0 && relativeOffset >= -255) { - add1(Bytecode.BR_U8); - add1(-relativeOffset); - } else { - add1(Bytecode.BR_I32); - add4(relativeOffset); + public enum BranchOp { + BR(op(Bytecode.BR_U8), op(Bytecode.BR_I32), false), + BR_IF(op(Bytecode.BR_IF_U8), op(Bytecode.BR_IF_I32), true), + BR_ON_NULL(miscOp(Bytecode.BR_ON_NULL_U8), miscOp(Bytecode.BR_ON_NULL_I32), true), + BR_ON_NON_NULL(miscOp(Bytecode.BR_ON_NON_NULL_U8), miscOp(Bytecode.BR_ON_NON_NULL_I32), true); + + private final byte[] opcodesU8; + private final byte[] opcodesI32; + private final boolean profiled; + + BranchOp(byte[] opcodesU8, byte[] opcodesI32, boolean profiled) { + this.opcodesU8 = opcodesU8; + this.opcodesI32 = opcodesI32; + this.profiled = profiled; } - } - /** - * Adds a br_i32 instruction to the bytecode and reserves an i32 value for the jump offset. - * - * @return The location of the jump offset to be patched later. (see - * {@link #patchLocation(int, int)}). - */ - public int addBranchLocation() { - add1(Bytecode.BR_I32); - final int location = location(); - add4(0); - return location; + public void emitOpcodesU8(RuntimeBytecodeGen bytecode) { + bytecode.addBytes(opcodesU8, 0, opcodesU8.length); + } + + public void emitOpcodesI32(RuntimeBytecodeGen bytecode) { + bytecode.addBytes(opcodesI32, 0, opcodesI32.length); + } + + public void emitProfile(RuntimeBytecodeGen bytecode) { + if (profiled) { + bytecode.addProfile(); + } + } + + private static byte[] op(int opcode) { + return new byte[]{(byte) opcode}; + } + + private static byte[] miscOp(int opcode) { + return new byte[]{(byte) Bytecode.MISC, (byte) opcode}; + } } /** - * Adds a conditional branch opcode to the bytecode. If the jump offset fits into a signed i8 - * value, a br_if_i8 and i8 jump offset is added. Otherwise, a br_if_i32 and i32 jump offset is - * added. In both cases, a profile with a size of 2-byte is added. - * + * Adds a branch opcode to the bytecode. If the jump offset fits into a signed i8 value, + * {@code opcodesU8} and i8 jump offset is added. Otherwise, {@code opcodesI32} and i32 jump + * offset is added. In both cases, a profile with a size of 2-byte is added. + * * @param location The target location of the branch. */ - public void addBranchIf(int location) { + public void addBranch(int location, BranchOp branchOp) { assert location >= 0; final int relativeOffset = location - (location() + 1); if (relativeOffset <= 0 && relativeOffset >= -255) { - add1(Bytecode.BR_IF_U8); + branchOp.emitOpcodesU8(this); // target add1(-relativeOffset); - // profile - addProfile(); } else { - add1(Bytecode.BR_IF_I32); + branchOp.emitOpcodesI32(this); // target add4(relativeOffset); - // profile - addProfile(); } + // profile + branchOp.emitProfile(this); } /** - * Adds a br_if_i32 opcode to the bytecode and reserves an i32 value for the jump offset. In + * Adds a branch opcode to the bytecode and reserves an i32 value for the jump offset. In * addition, a profile with a size of 2-byte is added. - * + * * @return The location of the jump offset to be patched later. (see * {@link #patchLocation(int, int)}) */ - public int addBranchIfLocation() { - add1(Bytecode.BR_IF_I32); + public int addBranchLocation(BranchOp branchOp) { + branchOp.emitOpcodesI32(this); final int location = location(); // target add4(0); // profile - addProfile(); - return location; - } - - public void addBranchOnNull(int location) { - assert location >= 0; - final int relativeOffset = location - (location() + 1); - if (relativeOffset <= 0 && relativeOffset >= -255) { - add1(Bytecode.MISC); - add1(Bytecode.BR_ON_NULL_U8); - // target - add1(-relativeOffset); - // profile - addProfile(); - } else { - add1(Bytecode.MISC); - add1(Bytecode.BR_ON_NULL_I32); - // target - add4(relativeOffset); - // profile - addProfile(); - } - } - - public int addBranchOnNullLocation() { - add1(Bytecode.MISC); - add1(Bytecode.BR_ON_NULL_I32); - final int location = location(); - // target - add4(0); - // profile - addProfile(); - return location; - } - - public void addBranchOnNonNull(int location) { - assert location >= 0; - final int relativeOffset = location - (location() + 1); - if (relativeOffset <= 0 && relativeOffset >= -255) { - add1(Bytecode.MISC); - add1(Bytecode.BR_ON_NON_NULL_U8); - // target - add1(-relativeOffset); - // profile - addProfile(); - } else { - add1(Bytecode.MISC); - add1(Bytecode.BR_ON_NON_NULL_I32); - // target - add4(relativeOffset); - // profile - addProfile(); - } - } - - public int addBranchOnNonNullLocation() { - add1(Bytecode.MISC); - add1(Bytecode.BR_ON_NON_NULL_I32); - final int location = location(); - // target - add4(0); - // profile - addProfile(); + branchOp.emitProfile(this); return location; } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/BlockFrame.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/BlockFrame.java index 4dfb545a06d8..fdfba38a1d2c 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/BlockFrame.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/BlockFrame.java @@ -102,23 +102,8 @@ void exit(RuntimeBytecodeGen bytecode) { } @Override - void addBranch(RuntimeBytecodeGen bytecode) { - branches.add(bytecode.addBranchLocation()); - } - - @Override - void addBranchIf(RuntimeBytecodeGen bytecode) { - branches.add(bytecode.addBranchIfLocation()); - } - - @Override - void addBranchOnNull(RuntimeBytecodeGen bytecode) { - branches.add(bytecode.addBranchOnNullLocation()); - } - - @Override - void addBranchOnNonNull(RuntimeBytecodeGen bytecode) { - branches.add(bytecode.addBranchOnNonNullLocation()); + void addBranch(RuntimeBytecodeGen bytecode, RuntimeBytecodeGen.BranchOp branchOp) { + branches.add(bytecode.addBranchLocation(branchOp)); } @Override diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ControlFrame.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ControlFrame.java index f6cc06575800..a8a211627e5d 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ControlFrame.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ControlFrame.java @@ -140,25 +140,12 @@ void initializeLocal(int localIndex) { abstract void exit(RuntimeBytecodeGen bytecode); /** - * Adds an unconditional branch targeting this control frame. Automatically patches the branch - * target as soon as it is available. - * - * @param bytecode The bytecode of the current control frame. - */ - abstract void addBranch(RuntimeBytecodeGen bytecode); - - /** - * Adds a conditional branch targeting this control frame. Automatically patches the branch - * target as soon as it is available. + * Adds a branch targeting this control frame. Automatically patches the branch target as soon + * as it is available. * * @param bytecode The bytecode of the current control frame. */ - - abstract void addBranchIf(RuntimeBytecodeGen bytecode); - - abstract void addBranchOnNull(RuntimeBytecodeGen bytecode); - - abstract void addBranchOnNonNull(RuntimeBytecodeGen bytecode); + abstract void addBranch(RuntimeBytecodeGen bytecode, RuntimeBytecodeGen.BranchOp branchOp); /** * Adds a branch table item targeting this control frame. Automatically patches the branch diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/IfFrame.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/IfFrame.java index 672d4ea05678..1dc4ec731a13 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/IfFrame.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/IfFrame.java @@ -78,7 +78,7 @@ int[] labelTypes() { @Override void enterElse(ParserState state, RuntimeBytecodeGen bytecode) { initializedLocals = (BitSet) parentFrame.initializedLocals.clone(); - final int location = bytecode.addBranchLocation(); + final int location = bytecode.addBranchLocation(RuntimeBytecodeGen.BranchOp.BR); bytecode.patchLocation(falseJumpLocation, bytecode.location()); falseJumpLocation = location; elseBranch = true; @@ -107,23 +107,8 @@ void exit(RuntimeBytecodeGen bytecode) { } @Override - void addBranch(RuntimeBytecodeGen bytecode) { - branchTargets.add(bytecode.addBranchLocation()); - } - - @Override - void addBranchIf(RuntimeBytecodeGen bytecode) { - branchTargets.add(bytecode.addBranchIfLocation()); - } - - @Override - void addBranchOnNull(RuntimeBytecodeGen bytecode) { - branchTargets.add(bytecode.addBranchOnNullLocation()); - } - - @Override - void addBranchOnNonNull(RuntimeBytecodeGen bytecode) { - branchTargets.add(bytecode.addBranchOnNonNullLocation()); + void addBranch(RuntimeBytecodeGen bytecode, RuntimeBytecodeGen.BranchOp branchOp) { + branchTargets.add(bytecode.addBranchLocation(branchOp)); } @Override diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/LoopFrame.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/LoopFrame.java index 2596f4a2da83..3aae0d9cfb3c 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/LoopFrame.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/LoopFrame.java @@ -73,23 +73,8 @@ void exit(RuntimeBytecodeGen bytecode) { } @Override - void addBranch(RuntimeBytecodeGen bytecode) { - bytecode.addBranch(labelLocation); - } - - @Override - void addBranchIf(RuntimeBytecodeGen bytecode) { - bytecode.addBranchIf(labelLocation); - } - - @Override - void addBranchOnNull(RuntimeBytecodeGen bytecode) { - bytecode.addBranchOnNull(labelLocation); - } - - @Override - void addBranchOnNonNull(RuntimeBytecodeGen bytecode) { - bytecode.addBranchOnNonNull(labelLocation); + void addBranch(RuntimeBytecodeGen bytecode, RuntimeBytecodeGen.BranchOp branchOp) { + bytecode.addBranch(labelLocation, branchOp); } @Override diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ParserState.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ParserState.java index cda383b27bcc..2c8caa2e3608 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ParserState.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ParserState.java @@ -401,7 +401,7 @@ public void addConditionalBranch(int branchLabel) { final int[] labelTypes = frame.labelTypes(); popAll(labelTypes); pushAll(labelTypes); - frame.addBranchIf(bytecode); + frame.addBranch(bytecode, RuntimeBytecodeGen.BranchOp.BR_IF); } /** @@ -415,7 +415,7 @@ public void addUnconditionalBranch(int branchLabel) { ControlFrame frame = getFrame(branchLabel); final int[] labelTypes = frame.labelTypes(); popAll(labelTypes); - frame.addBranch(bytecode); + frame.addBranch(bytecode, RuntimeBytecodeGen.BranchOp.BR); } public void addBranchOnNull(int branchLabel) { @@ -424,7 +424,7 @@ public void addBranchOnNull(int branchLabel) { final int[] labelTypes = frame.labelTypes(); popAll(labelTypes); pushAll(labelTypes); - frame.addBranchOnNull(bytecode); + frame.addBranch(bytecode, RuntimeBytecodeGen.BranchOp.BR_ON_NULL); } public void addBranchOnNonNull(int branchLabel, int referenceType) { @@ -443,7 +443,7 @@ public void addBranchOnNonNull(int branchLabel, int referenceType) { for (int i = 0; i < labelTypes.length - 1; i++) { push(labelTypes[i]); } - frame.addBranchOnNonNull(bytecode); + frame.addBranch(bytecode, RuntimeBytecodeGen.BranchOp.BR_ON_NON_NULL); } /** From a11ef3f3028278846781ed843375ef04009c6d10 Mon Sep 17 00:00:00 2001 From: Jirka Marsik Date: Tue, 7 Oct 2025 12:22:33 +0200 Subject: [PATCH 11/40] Fix immutable global error message --- .../src/org/graalvm/wasm/exception/Failure.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/Failure.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/Failure.java index a1e31b6f0477..2ffd9bc311c7 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/Failure.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/Failure.java @@ -93,7 +93,7 @@ public enum Failure { START_FUNCTION_PARAMS(Type.INVALID, "start function"), LIMIT_MINIMUM_GREATER_THAN_MAXIMUM(Type.INVALID, "size minimum must not be greater than maximum"), DUPLICATE_EXPORT(Type.INVALID, "duplicate export name"), - IMMUTABLE_GLOBAL_WRITE(Type.INVALID, "global is immutable"), + IMMUTABLE_GLOBAL_WRITE(Type.INVALID, "immutable global"), CONSTANT_EXPRESSION_REQUIRED(Type.INVALID, "constant expression required"), LIMIT_EXCEEDED(Type.INVALID, "limit exceeded"), MEMORY_SIZE_LIMIT_EXCEEDED(Type.INVALID, "memory size must be at most 65536 pages (4GiB)"), From feab7381a6792012c7ee0ecdde1cfd7ecd3028bb Mon Sep 17 00:00:00 2001 From: Jirka Marsik Date: Tue, 7 Oct 2025 12:23:11 +0200 Subject: [PATCH 12/40] Fix malformed limits flags failure type --- .../src/org/graalvm/wasm/BinaryParser.java | 18 +++--------------- .../org/graalvm/wasm/exception/Failure.java | 1 + 2 files changed, 4 insertions(+), 15 deletions(-) diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java index 1559b020153f..bb708842ac28 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java @@ -3457,11 +3457,7 @@ private void readLimits(int[] out, int max) { break; } default: - if (limitsPrefix < 0) { - fail(Failure.INTEGER_REPRESENTATION_TOO_LONG, "Invalid limits prefix (expected 0x00 or 0x01, got 0x%02X)", limitsPrefix); - } else { - fail(Failure.INTEGER_TOO_LARGE, "Invalid limits prefix (expected 0x00 or 0x01, got 0x%02X)", limitsPrefix); - } + fail(Failure.MALFORMED_LIMITS_FLAGS, "Invalid limits prefix (expected 0x00 or 0x01, got 0x%02X)", limitsPrefix); } } @@ -3498,11 +3494,7 @@ private void readLongLimits(long[] longOut, boolean[] boolOut, int max32Bit, lon } default: { if (!threads) { - if (limitsPrefix < 0) { - fail(Failure.INTEGER_REPRESENTATION_TOO_LONG, "Invalid limits prefix (expected 0x00, 0x01, 0x04, or 0x05, got 0x%02X)", limitsPrefix); - } else { - fail(Failure.INTEGER_TOO_LARGE, "Invalid limits prefix (expected 0x00, 0x01, 0x04, or 0x05, got 0x%02X)", limitsPrefix); - } + fail(Failure.MALFORMED_LIMITS_FLAGS, "Invalid limits prefix (expected 0x00, 0x01, 0x04, or 0x05, got 0x%02X)", limitsPrefix); } else { switch (limitsPrefix) { case 0x02: @@ -3525,11 +3517,7 @@ private void readLongLimits(long[] longOut, boolean[] boolOut, int max32Bit, lon break; } default: - if (limitsPrefix < 0) { - fail(Failure.INTEGER_REPRESENTATION_TOO_LONG, "Invalid limits prefix (expected 0x00-0x07, got 0x%02X)", limitsPrefix); - } else { - fail(Failure.INTEGER_TOO_LARGE, "Invalid limits prefix (expected 0x00-0x07, got 0x%02X)", limitsPrefix); - } + fail(Failure.MALFORMED_LIMITS_FLAGS, "Invalid limits prefix (expected 0x00-0x07, got 0x%02X)", limitsPrefix); } } } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/Failure.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/Failure.java index 2ffd9bc311c7..dc1782cf75b3 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/Failure.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/Failure.java @@ -66,6 +66,7 @@ public enum Failure { MALFORMED_IMPORT_KIND(Type.MALFORMED, "malformed import kind"), END_OPCODE_EXPECTED(Type.MALFORMED, "END opcode expected"), UNEXPECTED_CONTENT_AFTER_LAST_SECTION(Type.MALFORMED, "unexpected content after last section"), + MALFORMED_LIMITS_FLAGS(Type.MALFORMED, "malformed limits flags"), MALFORMED_MEMOP_FLAGS(Type.MALFORMED, "malformed memop flags"), MALFORMED_CATCH(Type.MALFORMED, "malformed catch clause"), MALFORMED_TAG_ATTRIBUTE(Type.MALFORMED, "malformed tag attribute"), From 1885228f10b94a9697519c7b2165673b8e92a4d6 Mon Sep 17 00:00:00 2001 From: Jirka Marsik Date: Tue, 7 Oct 2025 12:23:28 +0200 Subject: [PATCH 13/40] Fix unexpected end of const expr failure message --- .../src/org/graalvm/wasm/BinaryParser.java | 51 ++++++++++--------- 1 file changed, 28 insertions(+), 23 deletions(-) diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java index bb708842ac28..27ef79b0f9f8 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java @@ -193,6 +193,7 @@ private void readSymbolSections() { final int size = readLength(); final int startOffset = offset; + final int endOffset = startOffset + size; switch (sectionID) { case Section.CUSTOM: readCustomSection(size, customData); @@ -207,7 +208,7 @@ private void readSymbolSections() { readFunctionSection(); break; case Section.TABLE: - readTableSection(); + readTableSection(endOffset); break; case Section.MEMORY: readMemorySection(); @@ -216,7 +217,7 @@ private void readSymbolSections() { readTagSection(); break; case Section.GLOBAL: - readGlobalSection(); + readGlobalSection(endOffset); break; case Section.EXPORT: readExportSection(); @@ -225,7 +226,7 @@ private void readSymbolSections() { readStartSection(); break; case Section.ELEMENT: - readElementSection(bytecode); + readElementSection(bytecode, endOffset); break; case Section.DATA_COUNT: if (bulkMemoryAndRefTypes) { @@ -241,7 +242,7 @@ private void readSymbolSections() { break; case Section.DATA: dataSectionPresent = true; - readDataSection(bytecode); + readDataSection(bytecode, endOffset); break; } assertIntEqual(offset - startOffset, size, Failure.SECTION_SIZE_MISMATCH, "Declared section (0x%02X) size is incorrect", sectionID); @@ -490,7 +491,7 @@ private void readFunctionSection() { } } - private void readTableSection() { + private void readTableSection(int endOffset) { final int tableCount = readLength(); final int startingTableIndex = module.tableCount(); module.limits().checkTableCount(startingTableIndex + tableCount); @@ -504,7 +505,7 @@ private void readTableSection() { offset += 2; elemType = readRefType(); readTableLimits(multiResult); - Pair initExpression = readConstantExpression(elemType); + Pair initExpression = readConstantExpression(elemType, endOffset); initValue = initExpression.getLeft(); // Drop the initializer bytecode if we can eval the initializer during parsing initBytecode = initValue == null ? initExpression.getRight() : null; @@ -2651,11 +2652,11 @@ private ExceptionHandler[] readExceptionHandlers(ParserState state) { return handlers; } - private Pair readOffsetExpression() { + private Pair readOffsetExpression(int endOffset) { // Table offset expression must be a constant expression with result type i32. // https://webassembly.github.io/spec/core/syntax/modules.html#element-segments // https://webassembly.github.io/spec/core/valid/instructions.html#constant-expressions - Pair result = readConstantExpression(I32_TYPE); + Pair result = readConstantExpression(I32_TYPE, endOffset); if (result.getRight() == null) { return Pair.create((int) result.getLeft(), null); } else { @@ -2663,8 +2664,8 @@ private Pair readOffsetExpression() { } } - private Pair readLongOffsetExpression() { - Pair result = readConstantExpression(I64_TYPE); + private Pair readLongOffsetExpression(int endOffset) { + Pair result = readConstantExpression(I64_TYPE, endOffset); if (result.getRight() == null) { return Pair.create((long) result.getLeft(), null); } else { @@ -2672,7 +2673,7 @@ private Pair readLongOffsetExpression() { } } - private Pair readConstantExpression(int resultType) { + private Pair readConstantExpression(int resultType, int endOffset) { // Read the constant expression. // https://webassembly.github.io/spec/core/valid/instructions.html#constant-expressions final RuntimeBytecodeGen bytecode = new RuntimeBytecodeGen(); @@ -2682,8 +2683,9 @@ private Pair readConstantExpression(int resultType) { boolean calculable = true; state.enterFunction(EMPTY_TYPES, new int[]{resultType}, EMPTY_TYPES); - int opcode; - while ((opcode = read1() & 0xFF) != Instructions.END) { + int opcode = -1; + read_loop: while (offset < endOffset) { + opcode = read1() & 0xFF; switch (opcode) { case Instructions.I32_CONST: { final int value = readSignedInt32(); @@ -2811,11 +2813,14 @@ private Pair readConstantExpression(int resultType) { break; } break; + case Instructions.END: + break read_loop; default: fail(Failure.ILLEGAL_OPCODE, "Invalid instruction for constant expression: 0x%02X", opcode); break; } } + Assert.assertTrue(opcode == Instructions.END, Failure.UNEXPECTED_END); assertIntEqual(state.valueStackSize(), 1, Failure.TYPE_MISMATCH, "Unexpected number of results on stack at constant expression end"); state.exit(multiValue); if (calculable) { @@ -2906,7 +2911,7 @@ private long[] readElemExpressions(int elemType) { return elements; } - private void readElementSection(RuntimeBytecodeGen bytecode) { + private void readElementSection(RuntimeBytecodeGen bytecode, int endOffset) { int elemSegmentCount = readLength(); module.limits().checkElementSegmentCount(elemSegmentCount); for (int elemSegmentIndex = 0; elemSegmentIndex != elemSegmentCount; elemSegmentIndex++) { @@ -2929,7 +2934,7 @@ private void readElementSection(RuntimeBytecodeGen bytecode) { } else { tableIndex = 0; } - Pair offsetExpression = readOffsetExpression(); + Pair offsetExpression = readOffsetExpression(endOffset); currentOffsetAddress = offsetExpression.getLeft(); currentOffsetBytecode = offsetExpression.getRight(); } else { @@ -2957,7 +2962,7 @@ private void readElementSection(RuntimeBytecodeGen bytecode) { } else { mode = SegmentMode.ACTIVE; tableIndex = readTableIndex(); - Pair offsetExpression = readOffsetExpression(); + Pair offsetExpression = readOffsetExpression(endOffset); currentOffsetAddress = offsetExpression.getLeft(); currentOffsetBytecode = offsetExpression.getRight(); elemType = FUNCREF_TYPE; @@ -3070,7 +3075,7 @@ private void readTagSection() { } } - private void readGlobalSection() { + private void readGlobalSection(int endOffset) { final int globalCount = readLength(); module.limits().checkGlobalCount(globalCount); final int startingGlobalIndex = module.symbolTable().numGlobals(); @@ -3081,7 +3086,7 @@ private void readGlobalSection() { final byte mutability = readMutability(); // Global initialization expressions must be constant expressions: // https://webassembly.github.io/spec/core/valid/instructions.html#constant-expressions - Pair initExpression = readConstantExpression(type); + Pair initExpression = readConstantExpression(type, endOffset); final Object initValue = initExpression.getLeft(); final byte[] initBytecode = initExpression.getRight(); final boolean isInitialized = initBytecode == null; @@ -3098,7 +3103,7 @@ private void readDataCountSection(int size) { } } - private void readDataSection(RuntimeBytecodeGen bytecode) { + private void readDataSection(RuntimeBytecodeGen bytecode, int endOffset) { final int dataSegmentCount = readLength(); module.limits().checkDataSegmentCount(dataSegmentCount); if (bulkMemoryAndRefTypes) { @@ -3128,11 +3133,11 @@ private void readDataSection(RuntimeBytecodeGen bytecode) { if (mode == SegmentMode.ACTIVE) { checkMemoryIndex(memoryIndex); if (module.memoryHasIndexType64(memoryIndex)) { - Pair offsetExpression = readLongOffsetExpression(); + Pair offsetExpression = readLongOffsetExpression(endOffset); offsetAddress = offsetExpression.getLeft(); offsetBytecode = offsetExpression.getRight(); } else { - Pair offsetExpression = readOffsetExpression(); + Pair offsetExpression = readOffsetExpression(endOffset); offsetAddress = offsetExpression.getLeft(); offsetBytecode = offsetExpression.getRight(); } @@ -3149,11 +3154,11 @@ private void readDataSection(RuntimeBytecodeGen bytecode) { memoryIndex = 0; } if (module.memoryHasIndexType64(memoryIndex)) { - Pair offsetExpression = readLongOffsetExpression(); + Pair offsetExpression = readLongOffsetExpression(endOffset); offsetAddress = offsetExpression.getLeft(); offsetBytecode = offsetExpression.getRight(); } else { - Pair offsetExpression = readOffsetExpression(); + Pair offsetExpression = readOffsetExpression(endOffset); offsetAddress = offsetExpression.getLeft(); offsetBytecode = offsetExpression.getRight(); } From 68d6b789d0a11cc77c86a932ef7f63ed1ab1fd83 Mon Sep 17 00:00:00 2001 From: Jirka Marsik Date: Tue, 7 Oct 2025 12:33:31 +0200 Subject: [PATCH 14/40] Fix failure type reported for illegal opcodes --- .../org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java index 27ef79b0f9f8..840f5211a377 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java @@ -2445,11 +2445,11 @@ private void readNumericInstructions(ParserState state, int opcode) { state.addInstruction(vectorOpcodeToBytecode(vectorOpcode)); break; default: - fail(Failure.UNSPECIFIED_MALFORMED, "Unknown opcode: 0xFD 0x%02x", vectorOpcode); + fail(Failure.ILLEGAL_OPCODE, "Unknown opcode: 0xFD 0x%02x", vectorOpcode); } break; default: - fail(Failure.UNSPECIFIED_MALFORMED, "Unknown opcode: 0x%02x", opcode); + fail(Failure.ILLEGAL_OPCODE, "Unknown opcode: 0x%02x", opcode); break; } } From 0654ce2a490c6ad562d44914f9d9f4dff78e42af Mon Sep 17 00:00:00 2001 From: Jirka Marsik Date: Tue, 7 Oct 2025 20:11:06 +0200 Subject: [PATCH 15/40] Use TYPE_MISMATCH failure for uninitialized tables --- .../src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java | 2 +- .../src/org/graalvm/wasm/exception/Failure.java | 1 - .../src/org/graalvm/wasm/predefined/BuiltinModule.java | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java index 840f5211a377..8d4c39bb94d7 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java @@ -514,7 +514,7 @@ private void readTableSection(int endOffset) { readTableLimits(multiResult); initValue = null; initBytecode = null; - Assert.assertTrue(WasmType.isNullable(elemType), Failure.UNINITIALIZED_TABLE); + Assert.assertTrue(WasmType.isNullable(elemType), "uninitialized table of non-nullable element type", Failure.TYPE_MISMATCH); } module.symbolTable().declareTable(tableIndex, multiResult[0], multiResult[1], elemType, initBytecode, initValue, bulkMemoryAndRefTypes); } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/Failure.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/Failure.java index dc1782cf75b3..35e52721ec91 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/Failure.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/Failure.java @@ -108,7 +108,6 @@ public enum Failure { UNKNOWN_REFERENCE(Type.INVALID, "unknown reference"), UNDECLARED_FUNCTION_REFERENCE(Type.INVALID, "undeclared function reference"), UNKNOWN_TAG(Type.INVALID, "unknown tag"), - UNINITIALIZED_TABLE(Type.INVALID, "uninitialized table of non-nullable element type"), // GraalWasm-specific: MODULE_SIZE_LIMIT_EXCEEDED(Type.INVALID, "module size exceeds limit"), diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/predefined/BuiltinModule.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/predefined/BuiltinModule.java index e3a2196e537d..b9afdc82bd81 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/predefined/BuiltinModule.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/predefined/BuiltinModule.java @@ -122,7 +122,7 @@ protected int defineTable(WasmContext context, WasmModule module, String tableNa } else if (!referenceTypes && type != WasmType.FUNCREF_TYPE) { throw WasmException.create(Failure.UNSPECIFIED_MALFORMED, "Only function types are currently supported in tables."); } else if (!WasmType.isNullable(type)) { - throw WasmException.create(Failure.UNINITIALIZED_TABLE, "Tables of built-in modules must be nullable."); + throw WasmException.create(Failure.TYPE_MISMATCH, "Tables of built-in modules must be nullable."); } int index = module.symbolTable().tableCount(); module.symbolTable().declareTable(index, initSize, maxSize, type, null, null, referenceTypes); From df4e3bb2b80b52d44e9044e79308da35c7e99fde Mon Sep 17 00:00:00 2001 From: Jirka Marsik Date: Tue, 7 Oct 2025 20:12:03 +0200 Subject: [PATCH 16/40] Re-run table initializer during table state reset --- wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java | 2 +- .../org/graalvm/wasm/parser/bytecode/BytecodeParser.java | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java index e8413bf152bb..54f832b879b1 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java @@ -791,7 +791,7 @@ void resolvePassiveDataSegment(WasmStore store, WasmInstance instance, int dataS resolutionDag.resolveLater(new DataSym(instance.name(), dataSegmentId), dependencies.toArray(new Sym[0]), resolveAction); } - private static void initializeTable(WasmInstance instance, int tableIndex, Object initValue) { + public static void initializeTable(WasmInstance instance, int tableIndex, Object initValue) { int tableAddress = instance.tableAddress(tableIndex); WasmTable table = instance.store().tables().table(tableAddress); table.fill(0, table.size(), initValue); diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/BytecodeParser.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/BytecodeParser.java index 80dc61b2e2fc..cf96af2e9e08 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/BytecodeParser.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/BytecodeParser.java @@ -224,6 +224,13 @@ public static void resetMemoryState(WasmModule module, WasmInstance instance) { */ public static void resetTableState(WasmStore store, WasmModule module, WasmInstance instance) { final byte[] bytecode = module.bytecode(); + for (int tableIndex = 0; tableIndex < module.tableCount(); tableIndex++) { + if (module.tableInitialValue(tableIndex) != null) { + Linker.initializeTable(instance, tableIndex, module.tableInitialValue(tableIndex)); + } else if (module.tableInitializerBytecode(tableIndex) != null) { + Linker.initializeTable(instance, tableIndex, Linker.evalConstantExpression(instance, module.tableInitializerBytecode(tableIndex))); + } + } for (int i = 0; i < module.elemInstanceCount(); i++) { final int elemOffset = module.elemInstanceOffset(i); final int flags = bytecode[elemOffset]; From c8d61630bc3f3d20f5d3779fe971f98ba8f34eff Mon Sep 17 00:00:00 2001 From: Jirka Marsik Date: Tue, 7 Oct 2025 20:14:16 +0200 Subject: [PATCH 17/40] Enforce types at the typed wasm / untyped interop boundary Make the type checks around typed functions, global and table writes more precise. --- .../org/graalvm/wasm/test/WasmJsApiSuite.java | 4 +- .../src/org/graalvm/wasm/Linker.java | 11 +- .../src/org/graalvm/wasm/SymbolTable.java | 75 ++++++-- .../org/graalvm/wasm/WasmInstantiator.java | 2 +- .../src/org/graalvm/wasm/WasmTable.java | 19 ++- .../src/org/graalvm/wasm/WasmType.java | 5 +- .../wasm/api/ExecuteHostFunctionNode.java | 27 +-- .../wasm/api/InteropCallAdapterNode.java | 70 +------- .../src/org/graalvm/wasm/api/TableKind.java | 4 +- .../src/org/graalvm/wasm/api/ValueType.java | 37 ++-- .../src/org/graalvm/wasm/api/WebAssembly.java | 161 +++++++----------- .../org/graalvm/wasm/globals/WasmGlobal.java | 139 ++++++++------- .../graalvm/wasm/nodes/WasmFunctionNode.java | 2 +- 13 files changed, 259 insertions(+), 297 deletions(-) diff --git a/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/WasmJsApiSuite.java b/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/WasmJsApiSuite.java index b2a512cdbce0..30fe023253ec 100644 --- a/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/WasmJsApiSuite.java +++ b/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/WasmJsApiSuite.java @@ -502,7 +502,7 @@ public void testGlobalWriteNull() throws IOException { public void testGlobalWriteAnyfuncRefTypesDisabled() throws IOException { runTest(WasmJsApiSuite::disableRefTypes, context -> { final WebAssembly wasm = new WebAssembly(context); - final WasmGlobal global = new WasmGlobal(ValueType.anyfunc, true, WasmConstant.NULL); + final WasmGlobal global = WasmGlobal.allocRef(ValueType.anyfunc, true, WasmConstant.NULL); try { wasm.globalWrite(global, WasmConstant.NULL); Assert.fail("Should have failed - ref types not enabled"); @@ -516,7 +516,7 @@ public void testGlobalWriteAnyfuncRefTypesDisabled() throws IOException { public void testGlobalWriteExternrefRefTypesDisabled() throws IOException { runTest(WasmJsApiSuite::disableRefTypes, context -> { final WebAssembly wasm = new WebAssembly(context); - final WasmGlobal global = new WasmGlobal(ValueType.externref, true, WasmConstant.NULL); + final WasmGlobal global = WasmGlobal.allocRef(ValueType.externref, true, WasmConstant.NULL); try { wasm.globalWrite(global, WasmConstant.NULL); Assert.fail("Should have failed - ref types not enabled"); diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java index 54f832b879b1..573d3d0e8f3d 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java @@ -88,7 +88,6 @@ import org.graalvm.wasm.Linker.ResolutionDag.Resolver; import org.graalvm.wasm.Linker.ResolutionDag.Sym; import org.graalvm.wasm.api.ExecuteHostFunctionNode; -import org.graalvm.wasm.api.ValueType; import org.graalvm.wasm.api.Vector128; import org.graalvm.wasm.constants.Bytecode; import org.graalvm.wasm.constants.BytecodeBitEncoding; @@ -312,7 +311,7 @@ void resolveGlobalImport(WasmStore store, WasmInstance instance, ImportDescripto final int exportedValueType; final byte exportedMutability; if (externalGlobal != null) { - exportedValueType = externalGlobal.getValueType().value(); + exportedValueType = externalGlobal.getType(); exportedMutability = externalGlobal.getMutability(); } else { final WasmInstance importedInstance = store.lookupModuleInstance(importedModuleName); @@ -362,7 +361,7 @@ private static void initializeGlobal(WasmInstance instance, int globalIndex, Obj assert !instance.globals().isInitialized(globalIndex) : globalIndex; SymbolTable symbolTable = instance.symbolTable(); if (symbolTable.globalExternal(globalIndex)) { - var global = new WasmGlobal(ValueType.fromValue(symbolTable.globalValueType(globalIndex)), symbolTable.isGlobalMutable(globalIndex), initValue); + var global = new WasmGlobal(symbolTable.globalValueType(globalIndex), symbolTable.isGlobalMutable(globalIndex), instance.symbolTable(), initValue); instance.setExternalGlobal(globalIndex, global); } else { instance.globals().store(symbolTable.globalValueType(globalIndex), symbolTable.globalAddress(globalIndex), initValue); @@ -402,7 +401,7 @@ void resolveFunctionImport(WasmStore store, WasmInstance instance, WasmFunction Object externalFunctionInstance = lookupImportObject(instance, importDescriptor, imports, Object.class); if (externalFunctionInstance != null) { if (externalFunctionInstance instanceof WasmFunctionInstance functionInstance) { - if (!function.closedType().matches(functionInstance.function().closedType())) { + if (!function.closedType().matchesType(functionInstance.function().closedType())) { throw WasmException.create(Failure.INCOMPATIBLE_IMPORT_TYPE); } instance.setTarget(function.index(), functionInstance.target()); @@ -435,7 +434,7 @@ void resolveFunctionImport(WasmStore store, WasmInstance instance, WasmFunction throw WasmException.create(Failure.UNKNOWN_IMPORT, "The imported function '" + function.importedFunctionName() + "', referenced in the module '" + instance.name() + "', does not exist in the imported module '" + function.importedModuleName() + "'."); } - if (!function.closedType().matches(importedFunction.closedType())) { + if (!function.closedType().matchesType(importedFunction.closedType())) { throw WasmException.create(Failure.INCOMPATIBLE_IMPORT_TYPE); } final CallTarget target = importedInstance.target(importedFunction.index()); @@ -539,7 +538,7 @@ void resolveTagImport(WasmStore store, WasmInstance instance, ImportDescriptor i } importedTag = importedInstance.tag(exportedTagIndex); } - Assert.assertTrue(type.matches(importedTag.type()), Failure.INCOMPATIBLE_IMPORT_TYPE); + Assert.assertTrue(type.matchesType(importedTag.type()), Failure.INCOMPATIBLE_IMPORT_TYPE); instance.setTag(tagIndex, importedTag); }; resolutionDag.resolveLater(new ImportTagSym(instance.name(), importDescriptor, tagIndex), new Sym[]{new ExportTagSym(importedModuleName, importedTagName)}, resolveAction); diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java index 7de8d1538cf2..f3d923ff4c0b 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java @@ -52,10 +52,12 @@ import org.graalvm.collections.EconomicMap; import org.graalvm.collections.EconomicSet; import org.graalvm.collections.MapCursor; +import org.graalvm.wasm.api.Vector128; import org.graalvm.wasm.constants.GlobalModifier; import org.graalvm.wasm.constants.ImportIdentifier; import org.graalvm.wasm.exception.Failure; import org.graalvm.wasm.exception.WasmException; +import org.graalvm.wasm.exception.WasmRuntimeException; import org.graalvm.wasm.memory.WasmMemory; import org.graalvm.wasm.memory.WasmMemoryFactory; @@ -92,7 +94,9 @@ public enum Kind { Bottom } - public abstract boolean matches(ClosedValueType valueSubType); + public abstract boolean matchesType(ClosedValueType valueSubType); + + public abstract boolean matchesValue(Object value); public abstract Kind kind(); } @@ -104,7 +108,9 @@ public enum Kind { Function } - public abstract boolean matches(ClosedHeapType heapSubType); + public abstract boolean matchesType(ClosedHeapType heapSubType); + + public abstract boolean matchesValue(Object value); public abstract Kind kind(); } @@ -126,10 +132,21 @@ public int value() { } @Override - public boolean matches(ClosedValueType valueSubType) { + public boolean matchesType(ClosedValueType valueSubType) { return valueSubType == BottomType.BOTTOM || valueSubType == this; } + @Override + public boolean matchesValue(Object val) { + return switch (value()) { + case WasmType.I32_TYPE -> val instanceof Integer; + case WasmType.I64_TYPE -> val instanceof Long; + case WasmType.F32_TYPE -> val instanceof Float; + case WasmType.F64_TYPE -> val instanceof Double; + default -> throw CompilerDirectives.shouldNotReachHere(); + }; + } + @Override public Kind kind() { return Kind.Number; @@ -150,10 +167,15 @@ public int value() { } @Override - public boolean matches(ClosedValueType valueSubType) { + public boolean matchesType(ClosedValueType valueSubType) { return valueSubType == BottomType.BOTTOM || valueSubType == this; } + @Override + public boolean matchesValue(Object val) { + return val instanceof Vector128; + } + @Override public Kind kind() { return Kind.Vector; @@ -185,9 +207,14 @@ public ClosedHeapType heapType() { } @Override - public boolean matches(ClosedValueType valueSubType) { + public boolean matchesType(ClosedValueType valueSubType) { return valueSubType == BottomType.BOTTOM || valueSubType instanceof ClosedReferenceType referenceSubType && (!referenceSubType.nullable || this.nullable) && - this.closedHeapType.matches(referenceSubType.closedHeapType); + this.closedHeapType.matchesType(referenceSubType.closedHeapType); + } + + @Override + public boolean matchesValue(Object value) { + return nullable() && value == WasmConstant.NULL || heapType().matchesValue(value); } @Override @@ -212,7 +239,7 @@ public int value() { } @Override - public boolean matches(ClosedHeapType heapSubType) { + public boolean matchesType(ClosedHeapType heapSubType) { return switch (this.value) { case WasmType.FUNC_HEAPTYPE -> heapSubType == FUNC || heapSubType instanceof ClosedFunctionType; case WasmType.EXTERN_HEAPTYPE -> heapSubType == EXTERN; @@ -221,6 +248,16 @@ public boolean matches(ClosedHeapType heapSubType) { }; } + @Override + public boolean matchesValue(Object val) { + return switch (this.value) { + case WasmType.FUNC_HEAPTYPE -> val instanceof WasmFunctionInstance; + case WasmType.EXTERN_HEAPTYPE -> true; + case WasmType.EXN_HEAPTYPE -> val instanceof WasmRuntimeException; + default -> throw CompilerDirectives.shouldNotReachHere(); + }; + } + @Override public Kind kind() { return Kind.Abstract; @@ -245,7 +282,7 @@ public ClosedValueType[] resultTypes() { } @Override - public boolean matches(ClosedHeapType heapSubType) { + public boolean matchesType(ClosedHeapType heapSubType) { if (!(heapSubType instanceof ClosedFunctionType functionSubType)) { return false; } @@ -253,7 +290,7 @@ public boolean matches(ClosedHeapType heapSubType) { return false; } for (int i = 0; i < this.paramTypes.length; i++) { - if (!functionSubType.paramTypes[i].matches(this.paramTypes[i])) { + if (!functionSubType.paramTypes[i].matchesType(this.paramTypes[i])) { return false; } } @@ -261,13 +298,18 @@ public boolean matches(ClosedHeapType heapSubType) { return false; } for (int i = 0; i < this.resultTypes.length; i++) { - if (!this.resultTypes[i].matches(functionSubType.resultTypes[i])) { + if (!this.resultTypes[i].matchesType(functionSubType.resultTypes[i])) { return false; } } return true; } + @Override + public boolean matchesValue(Object value) { + return value instanceof WasmFunctionInstance instance && matchesType(instance.function().closedType()); + } + @Override public Kind kind() { return Kind.Function; @@ -281,10 +323,15 @@ private BottomType() { } @Override - public boolean matches(ClosedValueType valueSubType) { + public boolean matchesType(ClosedValueType valueSubType) { return valueSubType instanceof BottomType; } + @Override + public boolean matchesValue(Object value) { + return false; + } + @Override public Kind kind() { return Kind.Bottom; @@ -849,7 +896,7 @@ yield switch (WasmType.getAbstractHeapType(type)) { } public boolean matches(int expectedType, int actualType) { - return closedTypeAt(expectedType).matches(closedTypeAt(actualType)); + return closedTypeAt(expectedType).matchesType(closedTypeAt(actualType)); } public void importSymbol(ImportDescriptor descriptor) { @@ -1121,9 +1168,9 @@ public void declareTable(int index, int declaredMinSize, int declaredMaxSize, in final WasmTable wasmTable; if (context.getContextOptions().memoryOverheadMode()) { // Initialize an empty table in memory overhead mode. - wasmTable = new WasmTable(0, 0, 0, elemType); + wasmTable = new WasmTable(0, 0, 0, elemType, this); } else { - wasmTable = new WasmTable(declaredMinSize, declaredMaxSize, maxAllowedSize, elemType); + wasmTable = new WasmTable(declaredMinSize, declaredMaxSize, maxAllowedSize, elemType, this); } final int address = store.tables().register(wasmTable); instance.setTableAddress(index, address); diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmInstantiator.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmInstantiator.java index 50d3c4ac9e5a..2bc19df720c5 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmInstantiator.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmInstantiator.java @@ -151,7 +151,7 @@ static List recreateLinkActions(WasmModule module) { final ModuleLimits limits = instance.module().limits(); final int maxAllowedSize = WasmMath.minUnsigned(tableMaxSize, limits.tableInstanceSizeLimit()); limits.checkTableInstanceSize(tableMinSize); - final WasmTable wasmTable = new WasmTable(tableMinSize, tableMaxSize, maxAllowedSize, tableElemType); + final WasmTable wasmTable = new WasmTable(tableMinSize, tableMaxSize, maxAllowedSize, tableElemType, module); final int address = store.tables().register(wasmTable); instance.setTableAddress(tableIndex, address); diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmTable.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmTable.java index 2473e9f94bea..6b66dbc6d16e 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmTable.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmTable.java @@ -67,6 +67,12 @@ public final class WasmTable extends EmbedderDataHolder implements TruffleObject */ private final int elemType; + /** + * For resolving {@link #elemType} in {@link #closedValueType()}. Can be {@code null} for tables + * allocated from JS. + */ + private final SymbolTable symbolTable; + /** * @see #minSize() */ @@ -86,7 +92,7 @@ public final class WasmTable extends EmbedderDataHolder implements TruffleObject private Object[] elements; @TruffleBoundary - private WasmTable(int declaredMinSize, int declaredMaxSize, int initialSize, int maxAllowedSize, int elemType, Object initialValue) { + private WasmTable(int declaredMinSize, int declaredMaxSize, int initialSize, int maxAllowedSize, int elemType, Object initialValue, SymbolTable symbolTable) { assert compareUnsigned(declaredMinSize, initialSize) <= 0; assert compareUnsigned(initialSize, maxAllowedSize) <= 0; assert compareUnsigned(maxAllowedSize, declaredMaxSize) <= 0; @@ -101,14 +107,15 @@ private WasmTable(int declaredMinSize, int declaredMaxSize, int initialSize, int this.elements = new Object[declaredMinSize]; Arrays.fill(this.elements, initialValue); this.elemType = elemType; + this.symbolTable = symbolTable; } - public WasmTable(int declaredMinSize, int declaredMaxSize, int maxAllowedSize, int elemType) { - this(declaredMinSize, declaredMaxSize, declaredMinSize, maxAllowedSize, elemType, WasmConstant.NULL); + public WasmTable(int declaredMinSize, int declaredMaxSize, int maxAllowedSize, int elemType, SymbolTable symbolTable) { + this(declaredMinSize, declaredMaxSize, declaredMinSize, maxAllowedSize, elemType, WasmConstant.NULL, symbolTable); } public WasmTable(int declaredMinSize, int declaredMaxSize, int maxAllowedSize, int elemType, Object initialValue) { - this(declaredMinSize, declaredMaxSize, declaredMinSize, maxAllowedSize, elemType, initialValue); + this(declaredMinSize, declaredMaxSize, declaredMinSize, maxAllowedSize, elemType, initialValue, null); } /** @@ -163,6 +170,10 @@ public int elemType() { return elemType; } + public SymbolTable.ClosedValueType closedValueType() { + return symbolTable.closedTypeAt(elemType); + } + /** * The current minimum size of the table. The size can change based on calls to * {@link #grow(int, Object)}. diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmType.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmType.java index a4b350b2b539..9011d7e587ee 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmType.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmType.java @@ -98,8 +98,7 @@ public class WasmType implements TruffleObject { /** * Implementation-specific Types. */ - public static final int TOP = -0x7d; - public static final int NULL_TYPE = -0x7e; + public static final int TOP = -0x7e; public static final int BOT = -0x7f; /** @@ -131,6 +130,8 @@ public static String toString(int valueType) { case F32_TYPE -> "f32"; case F64_TYPE -> "f64"; case V128_TYPE -> "v128"; + case TOP -> "top"; + case BOT -> "bot"; default -> { assert WasmType.isReferenceType(valueType); boolean nullable = WasmType.isNullable(valueType); diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ExecuteHostFunctionNode.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ExecuteHostFunctionNode.java index f4b79c91106f..363d202492ef 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ExecuteHostFunctionNode.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ExecuteHostFunctionNode.java @@ -40,6 +40,7 @@ */ package org.graalvm.wasm.api; +import org.graalvm.wasm.SymbolTable; import org.graalvm.wasm.WasmArguments; import org.graalvm.wasm.WasmConstant; import org.graalvm.wasm.WasmContext; @@ -127,21 +128,16 @@ public Object execute(VirtualFrame frame) { */ private Object convertResult(Object result, int resultType) throws UnsupportedMessageException { CompilerAsserts.partialEvaluationConstant(resultType); + SymbolTable.ClosedValueType closedResultType = module.closedTypeAt(resultType); + CompilerAsserts.partialEvaluationConstant(closedResultType); return switch (resultType) { case WasmType.I32_TYPE -> asInt(result); case WasmType.I64_TYPE -> asLong(result); case WasmType.F32_TYPE -> asFloat(result); case WasmType.F64_TYPE -> asDouble(result); - case WasmType.V128_TYPE -> { - if (!(result instanceof Vector128)) { - errorBranch.enter(); - throw WasmException.create(Failure.TYPE_MISMATCH); - } - yield result; - } default -> { - assert WasmType.isReferenceType(resultType); - if (!WasmType.isNullable(resultType) && result == WasmConstant.NULL) { + assert WasmType.isVectorType(resultType) || WasmType.isReferenceType(resultType); + if (!closedResultType.matchesValue(result)) { errorBranch.enter(); throw WasmException.create(Failure.TYPE_MISMATCH); } @@ -170,22 +166,17 @@ private void pushMultiValueResult(Object result, int resultCount) { for (int i = 0; i < resultCount; i++) { int resultType = module.symbolTable().functionTypeResultTypeAt(functionTypeIndex, i); CompilerAsserts.partialEvaluationConstant(resultType); + SymbolTable.ClosedValueType closedResultType = module.closedTypeAt(resultType); + CompilerAsserts.partialEvaluationConstant(closedResultType); Object value = arrayInterop.readArrayElement(result, i); switch (resultType) { case WasmType.I32_TYPE -> primitiveMultiValueStack[i] = asInt(value); case WasmType.I64_TYPE -> primitiveMultiValueStack[i] = asLong(value); case WasmType.F32_TYPE -> primitiveMultiValueStack[i] = Float.floatToRawIntBits(asFloat(value)); case WasmType.F64_TYPE -> primitiveMultiValueStack[i] = Double.doubleToRawLongBits(asDouble(value)); - case WasmType.V128_TYPE -> { - if (!(value instanceof Vector128)) { - errorBranch.enter(); - throw WasmException.create(Failure.INVALID_TYPE_IN_MULTI_VALUE); - } - objectMultiValueStack[i] = value; - } default -> { - assert WasmType.isReferenceType(resultType); - if (!WasmType.isNullable(resultType) && value == WasmConstant.NULL) { + assert WasmType.isVectorType(resultType) || WasmType.isReferenceType(resultType); + if (!closedResultType.matchesValue(result)) { throw WasmException.create(Failure.INVALID_TYPE_IN_MULTI_VALUE); } objectMultiValueStack[i] = value; diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/InteropCallAdapterNode.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/InteropCallAdapterNode.java index 04dbb68168f4..f791a52c852f 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/InteropCallAdapterNode.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/InteropCallAdapterNode.java @@ -42,12 +42,10 @@ import org.graalvm.wasm.SymbolTable; import org.graalvm.wasm.WasmArguments; -import org.graalvm.wasm.WasmConstant; import org.graalvm.wasm.WasmContext; import org.graalvm.wasm.WasmFunctionInstance; import org.graalvm.wasm.WasmLanguage; import org.graalvm.wasm.WasmType; -import org.graalvm.wasm.exception.WasmRuntimeException; import org.graalvm.wasm.nodes.WasmIndirectCallNode; import com.oracle.truffle.api.CallTarget; @@ -135,73 +133,9 @@ private static void validateArgumentsUnroll(Object[] arguments, int offset, Symb private static void validateArgument(Object[] arguments, int offset, SymbolTable.ClosedValueType[] paramTypes, int i) throws UnsupportedTypeException { SymbolTable.ClosedValueType paramType = paramTypes[i]; Object value = arguments[i + offset]; - switch (paramType.kind()) { - case Number -> { - SymbolTable.NumberType numberType = (SymbolTable.NumberType) paramType; - switch (numberType.value()) { - case WasmType.I32_TYPE -> { - if (value instanceof Integer) { - return; - } - } - case WasmType.I64_TYPE -> { - if (value instanceof Long) { - return; - } - } - case WasmType.F32_TYPE -> { - if (value instanceof Float) { - return; - } - } - case WasmType.F64_TYPE -> { - if (value instanceof Double) { - return; - } - } - } - } - case Vector -> { - if (value instanceof Vector128) { - return; - } - } - case Reference -> { - SymbolTable.ClosedReferenceType referenceType = (SymbolTable.ClosedReferenceType) paramType; - boolean nullable = referenceType.nullable(); - SymbolTable.ClosedHeapType heapType = referenceType.heapType(); - switch (heapType.kind()) { - case Abstract -> { - SymbolTable.AbstractHeapType abstractHeapType = (SymbolTable.AbstractHeapType) heapType; - switch (abstractHeapType.value()) { - case WasmType.FUNCREF_TYPE -> { - if (value instanceof WasmFunctionInstance || nullable && value == WasmConstant.NULL) { - return; - } - } - case WasmType.EXTERNREF_TYPE -> { - if (nullable || value != WasmConstant.NULL) { - return; - } - } - case WasmType.EXNREF_TYPE -> { - if (value instanceof WasmRuntimeException || nullable && value == WasmConstant.NULL) { - return; - } - } - } - } - case Function -> { - SymbolTable.ClosedFunctionType functionType = (SymbolTable.ClosedFunctionType) heapType; - if (value instanceof WasmFunctionInstance instance && functionType.matches(instance.function().closedType()) || nullable && value == WasmConstant.NULL) { - return; - } - } - } - } - case Bottom -> throw CompilerDirectives.shouldNotReachHere(); + if (!paramType.matchesValue(value)) { + throw UnsupportedTypeException.create(arguments); } - throw UnsupportedTypeException.create(arguments); } private Object multiValueStackAsArray(WasmLanguage language) { diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/TableKind.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/TableKind.java index 64f53c46a53f..91ab0519f9d5 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/TableKind.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/TableKind.java @@ -44,8 +44,7 @@ public enum TableKind { externref(WasmType.EXTERNREF_TYPE), - anyfunc(WasmType.FUNCREF_TYPE), - exnref(WasmType.EXNREF_TYPE); + anyfunc(WasmType.FUNCREF_TYPE); private final int value; @@ -61,7 +60,6 @@ public static String toString(int value) { return switch (value) { case WasmType.EXTERNREF_TYPE -> "externref"; case WasmType.FUNCREF_TYPE -> "anyfunc"; - case WasmType.EXNREF_TYPE -> "exnref"; default -> ""; }; } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ValueType.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ValueType.java index 5a27fc9a6afb..973ab62a8eab 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ValueType.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ValueType.java @@ -70,9 +70,18 @@ public static ValueType fromValue(int value) { case WasmType.F32_TYPE -> f32; case WasmType.F64_TYPE -> f64; case WasmType.V128_TYPE -> v128; - case WasmType.FUNCREF_TYPE -> anyfunc; - case WasmType.EXTERNREF_TYPE -> externref; - default -> throw WasmException.create(Failure.UNSPECIFIED_INTERNAL, null, "Unknown value type: 0x" + Integer.toHexString(value)); + default -> { + assert WasmType.isReferenceType(value); + yield switch (WasmType.getAbstractHeapType(value)) { + case WasmType.FUNC_HEAPTYPE -> anyfunc; + case WasmType.EXTERN_HEAPTYPE -> externref; + case WasmType.EXN_HEAPTYPE -> throw WasmException.create(Failure.UNSPECIFIED_INTERNAL, null, "Unknown value type: 0x" + Integer.toHexString(value)); + default -> { + assert WasmType.isConcreteReferenceType(value); + yield anyfunc; + } + }; + } }; } @@ -93,15 +102,14 @@ public static boolean isReferenceType(ValueType valueType) { } public SymbolTable.ClosedValueType asClosedValueType() { - return switch (value) { - case WasmType.I32_TYPE -> SymbolTable.NumberType.I32; - case WasmType.I64_TYPE -> SymbolTable.NumberType.I64; - case WasmType.F32_TYPE -> SymbolTable.NumberType.F32; - case WasmType.F64_TYPE -> SymbolTable.NumberType.F64; - case WasmType.V128_TYPE -> SymbolTable.VectorType.V128; - case WasmType.FUNCREF_TYPE -> SymbolTable.ClosedReferenceType.FUNCREF; - case WasmType.EXTERNREF_TYPE -> SymbolTable.ClosedReferenceType.EXTERNREF; - default -> throw WasmException.create(Failure.UNSPECIFIED_INTERNAL, null, "Unknown value type: 0x" + Integer.toHexString(value)); + return switch (this) { + case i32 -> SymbolTable.NumberType.I32; + case i64 -> SymbolTable.NumberType.I64; + case f32 -> SymbolTable.NumberType.F32; + case f64 -> SymbolTable.NumberType.F64; + case v128 -> SymbolTable.VectorType.V128; + case anyfunc -> SymbolTable.ClosedReferenceType.FUNCREF; + case externref -> SymbolTable.ClosedReferenceType.EXTERNREF; }; } @@ -111,9 +119,6 @@ public static ValueType fromClosedValueType(SymbolTable.ClosedValueType closedVa case Vector -> fromValue(((SymbolTable.VectorType) closedValueType).value()); case Reference -> { SymbolTable.ClosedReferenceType referenceType = (SymbolTable.ClosedReferenceType) closedValueType; - if (!referenceType.nullable()) { - throw WasmException.create(Failure.UNSPECIFIED_INTERNAL, null, "Unknown value type: non-nullable reference"); - } yield switch (referenceType.heapType().kind()) { case Abstract -> { SymbolTable.AbstractHeapType abstractHeapType = (SymbolTable.AbstractHeapType) referenceType.heapType(); @@ -123,7 +128,7 @@ yield switch (abstractHeapType.value()) { default -> throw WasmException.create(Failure.UNSPECIFIED_INTERNAL, null, "Unknown value type: 0x" + Integer.toHexString(abstractHeapType.value())); }; } - case Function -> throw WasmException.create(Failure.UNSPECIFIED_INTERNAL, null, "Unknown value type: typed function reference"); + case Function -> anyfunc; }; } case Bottom -> throw WasmException.create(Failure.UNSPECIFIED_INTERNAL, null, "Unknown value type: bottom"); diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/WebAssembly.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/WebAssembly.java index c9b3f555705e..cae2a1166adb 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/WebAssembly.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/WebAssembly.java @@ -112,7 +112,7 @@ public WebAssembly(WasmContext currentContext) { addMember("mem_set_wait_callback", new Executable(WebAssembly::memSetWaitCallback)); addMember("global_alloc", new Executable(this::globalAlloc)); - addMember("global_read", new Executable(WebAssembly::globalRead)); + addMember("global_read", new Executable(this::globalRead)); addMember("global_write", new Executable(this::globalWrite)); addMember("tag_alloc", new Executable(WebAssembly::tagAlloc)); @@ -442,10 +442,10 @@ public WasmTable tableAlloc(int initial, int maximum, TableKind elemKind, Object if (Integer.compareUnsigned(initial, JS_LIMITS.tableInstanceSizeLimit()) > 0) { throw new WasmJsApiException(WasmJsApiException.Kind.RangeError, "Min table size exceeds implementation limit"); } - if (elemKind != TableKind.externref && elemKind != TableKind.anyfunc && elemKind != TableKind.exnref) { + if (elemKind != TableKind.externref && elemKind != TableKind.anyfunc) { throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Element type must be a reftype"); } - if (!refTypes && (elemKind == TableKind.externref || elemKind == TableKind.exnref)) { + if (!refTypes && elemKind == TableKind.externref) { throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Element type must be anyfunc. Enable wasm.BulkMemoryAndRefTypes to support other reference types"); } final int maxAllowedSize = minUnsigned(maximum, JS_LIMITS.tableInstanceSizeLimit()); @@ -508,23 +508,12 @@ private Object tableWrite(Object[] args) { } public Object tableWrite(WasmTable table, int index, Object element) { - final Object elem; - if (element instanceof WasmFunctionInstance) { - elem = element; - } else if (element == WasmConstant.NULL) { - elem = WasmConstant.NULL; - } else { - if (!currentContext.getContextOptions().supportBulkMemoryAndRefTypes()) { - throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Invalid table element"); - } - if (table.elemType() == WasmType.FUNCREF_TYPE) { - throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Invalid table element"); - } - elem = element; + if (!table.closedValueType().matchesValue(element)) { + throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Invalid table element"); } try { - table.set(index, elem); + table.set(index, element); } catch (ArrayIndexOutOfBoundsException e) { throw new WasmJsApiException(WasmJsApiException.Kind.RangeError, "Table index out of bounds: " + e.getMessage()); } @@ -799,79 +788,68 @@ private Object globalAlloc(Object[] args) { String valueTypeString = lib.asString(args[0]); valueType = ValueType.valueOf(valueTypeString); } catch (UnsupportedMessageException e) { - throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "First argument (value type) must be convertible to String"); + throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "First argument (value type) must be convertible to String."); } catch (IllegalArgumentException ex) { - throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Invalid value type"); + throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Invalid value type."); } final boolean mutable; try { mutable = lib.asBoolean(args[1]); } catch (UnsupportedMessageException e) { - throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "First argument (mutable) must be convertible to boolean"); + throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "First argument (mutable) must be convertible to boolean."); } return globalAlloc(valueType, mutable, args[2]); } public WasmGlobal globalAlloc(ValueType valueType, boolean mutable, Object value) { - InteropLibrary valueInterop = InteropLibrary.getUncached(value); - try { - switch (valueType) { - case i32: - return new WasmGlobal(valueType, mutable, valueInterop.asInt(value)); - case i64: - return new WasmGlobal(valueType, mutable, valueInterop.asLong(value)); - case f32: - return new WasmGlobal(valueType, mutable, Float.floatToRawIntBits(valueInterop.asFloat(value))); - case f64: - return new WasmGlobal(valueType, mutable, Double.doubleToRawLongBits(valueInterop.asDouble(value))); - case anyfunc: - if (!refTypes || !(value == WasmConstant.NULL || value instanceof WasmFunctionInstance)) { - throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Invalid value type"); - } - return new WasmGlobal(valueType, mutable, value); - case externref: - if (!refTypes) { - throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Invalid value type"); - } - return new WasmGlobal(valueType, mutable, value); - default: - throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Invalid value type"); + if (!valueType.asClosedValueType().matchesValue(value)) { + throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Global type %s, value: %s", valueType, value); + } + return switch (valueType) { + case i32 -> WasmGlobal.alloc32(valueType, mutable, (int) value); + case i64 -> WasmGlobal.alloc64(valueType, mutable, (long) value); + case f32 -> WasmGlobal.alloc32(valueType, mutable, Float.floatToRawIntBits((float) value)); + case f64 -> WasmGlobal.alloc64(valueType, mutable, Double.doubleToRawLongBits((double) value)); + case v128 -> throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Invalid value type. Accessing v128 globals from JS is not allowed."); + case anyfunc, externref -> { + if (!refTypes) { + throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Invalid value type. Reference types are not enabled."); + } + yield WasmGlobal.allocRef(valueType, mutable, value); } - } catch (UnsupportedMessageException ex) { - throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Cannot convert value to the specified value type"); - } + }; } - private static Object globalRead(Object[] args) { + private Object globalRead(Object[] args) { checkArgumentCount(args, 1); if (!(args[0] instanceof WasmGlobal global)) { - throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "First argument must be wasm global"); + throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "First argument must be wasm global."); } return globalRead(global); } - public static Object globalRead(WasmGlobal global) { - switch (global.getValueType()) { - case i32: - return global.loadAsInt(); - case i64: - return global.loadAsLong(); - case f32: - return Float.intBitsToFloat(global.loadAsInt()); - case f64: - return Double.longBitsToDouble(global.loadAsLong()); - case anyfunc: - case externref: - return global.loadAsReference(); - - } - throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Incorrect internal Global type"); + public Object globalRead(WasmGlobal global) { + int type = global.getType(); + return switch (type) { + case WasmType.I32_TYPE -> global.loadAsInt(); + case WasmType.I64_TYPE -> global.loadAsLong(); + case WasmType.F32_TYPE -> Float.intBitsToFloat(global.loadAsInt()); + case WasmType.F64_TYPE -> Double.longBitsToDouble(global.loadAsLong()); + case WasmType.V128_TYPE -> throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Invalid value type. Accessing v128 globals from JS is not allowed."); + default -> { + assert WasmType.isReferenceType(type); + if (!refTypes) { + throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Invalid value type. Reference types are not enabled."); + } + yield global.loadAsReference(); + } + }; } private Object globalWrite(Object[] args) { checkArgumentCount(args, 2); if (!(args[0] instanceof WasmGlobal global)) { - throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "First argument must be wasm global"); + throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "First argument must be wasm global."); } return globalWrite(global, args[1]); } @@ -880,48 +858,23 @@ public Object globalWrite(WasmGlobal global, Object value) { if (!global.isMutable()) { throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Global is not mutable."); } - ValueType valueType = global.getValueType(); - switch (valueType) { - case i32: - if (!(value instanceof Integer)) { - throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Global type %s, value: %s", valueType, value); - } - global.storeInt((int) value); - break; - case i64: - if (!(value instanceof Long)) { - throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Global type %s, value: %s", valueType, value); - } - global.storeLong((long) value); - break; - case f32: - if (!(value instanceof Float)) { - throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Global type %s, value: %s", valueType, value); - } - global.storeInt(Float.floatToRawIntBits((float) value)); - break; - case f64: - if (!(value instanceof Double)) { - throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Global type %s, value: %s", valueType, value); - } - global.storeLong(Double.doubleToRawLongBits((double) value)); - break; - case anyfunc: + int type = global.getType(); + if (!global.getClosedValueType().matchesValue(value)) { + throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Global type %s, value: %s", ValueType.fromValue(type), value); + } + switch (type) { + case WasmType.I32_TYPE -> global.storeInt((int) value); + case WasmType.I64_TYPE -> global.storeLong((long) value); + case WasmType.F32_TYPE -> global.storeInt(Float.floatToRawIntBits((float) value)); + case WasmType.F64_TYPE -> global.storeLong(Double.doubleToRawLongBits((double) value)); + case WasmType.V128_TYPE -> throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Invalid value type. Accessing v128 globals from JS is not allowed."); + default -> { + assert WasmType.isReferenceType(type); if (!refTypes) { - throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Invalid value type. Reference types are not enabled"); - } - if (!(value == WasmConstant.NULL || value instanceof WasmFunctionInstance)) { - throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Global type %s, value: %s", valueType, value); - } else { - global.storeReference(value); - } - break; - case externref: - if (!refTypes) { - throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Invalid value type. Reference types are not enabled"); + throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Invalid value type. Reference types are not enabled."); } global.storeReference(value); - break; + } } return WasmConstant.VOID; } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/globals/WasmGlobal.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/globals/WasmGlobal.java index 8f5b1a1a3edc..cb27b08d2f06 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/globals/WasmGlobal.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/globals/WasmGlobal.java @@ -42,9 +42,9 @@ package org.graalvm.wasm.globals; import org.graalvm.wasm.EmbedderDataHolder; -import org.graalvm.wasm.WasmConstant; -import org.graalvm.wasm.WasmFunctionInstance; +import org.graalvm.wasm.SymbolTable; import org.graalvm.wasm.WasmNamesObject; +import org.graalvm.wasm.WasmType; import org.graalvm.wasm.api.ValueType; import org.graalvm.wasm.api.Vector128; import org.graalvm.wasm.constants.GlobalModifier; @@ -62,44 +62,68 @@ @ExportLibrary(InteropLibrary.class) public final class WasmGlobal extends EmbedderDataHolder implements TruffleObject { - private final ValueType valueType; + private final int type; private final boolean mutable; + private final SymbolTable symbolTable; private long globalValue; private Object globalObjectValue; - private WasmGlobal(ValueType valueType, boolean mutable) { - this.valueType = valueType; + private WasmGlobal(int type, boolean mutable, SymbolTable symbolTable) { + this.type = type; this.mutable = mutable; + this.symbolTable = symbolTable; } - public WasmGlobal(ValueType valueType, boolean mutable, int value) { - this(valueType, mutable, (long) value); + public static WasmGlobal alloc32(ValueType valueType, boolean mutable, int value) { + return alloc64(valueType, mutable, value); } - public WasmGlobal(ValueType valueType, boolean mutable, long value) { - this(valueType, mutable); - assert ValueType.isNumberType(getValueType()); - this.globalValue = value; + public static WasmGlobal alloc64(ValueType valueType, boolean mutable, long value) { + assert ValueType.isNumberType(valueType); + WasmGlobal result = new WasmGlobal(valueType.value(), mutable, null); + result.globalValue = value; + return result; + } + + public static WasmGlobal allocRef(ValueType valueType, boolean mutable, Object value) { + assert ValueType.isReferenceType(valueType); + WasmGlobal result = new WasmGlobal(valueType.value(), mutable, null); + result.globalObjectValue = value; + return result; } - public WasmGlobal(ValueType valueType, boolean mutable, Object value) { - this(valueType, mutable); - this.globalValue = switch (valueType) { - case i32 -> (int) value; - case i64 -> (long) value; - case f32 -> Float.floatToRawIntBits((float) value); - case f64 -> Double.doubleToRawLongBits((double) value); + public WasmGlobal(int type, boolean mutable, SymbolTable symbolTable, Object value) { + this(type, mutable, symbolTable); + this.globalValue = switch (type) { + case WasmType.I32_TYPE -> (int) value; + case WasmType.I64_TYPE -> (long) value; + case WasmType.F32_TYPE -> Float.floatToRawIntBits((float) value); + case WasmType.F64_TYPE -> Double.doubleToRawLongBits((double) value); default -> 0; }; - this.globalObjectValue = switch (valueType) { - case v128, anyfunc, externref -> value; - default -> null; - }; + this.globalObjectValue = WasmType.isVectorType(type) || WasmType.isReferenceType(type) ? value : null; + } + + public int getType() { + return type; } - public ValueType getValueType() { - return valueType; + public SymbolTable.ClosedValueType getClosedValueType() { + if (symbolTable != null) { + return symbolTable.closedTypeAt(getType()); + } else { + // Global was created by WebAssembly#global_alloc + return switch (ValueType.fromValue(getType())) { + case i32 -> SymbolTable.NumberType.I32; + case i64 -> SymbolTable.NumberType.I64; + case f32 -> SymbolTable.NumberType.F32; + case f64 -> SymbolTable.NumberType.F64; + case v128 -> SymbolTable.VectorType.V128; + case anyfunc -> SymbolTable.ClosedReferenceType.FUNCREF; + case externref -> SymbolTable.ClosedReferenceType.EXTERNREF; + }; + } } public boolean isMutable() { @@ -111,44 +135,44 @@ public byte getMutability() { } public int loadAsInt() { - assert ValueType.isNumberType(getValueType()); + assert WasmType.isNumberType(getType()); return (int) globalValue; } public long loadAsLong() { - assert ValueType.isNumberType(getValueType()); + assert WasmType.isNumberType(getType()); return globalValue; } public Vector128 loadAsVector128() { - assert ValueType.isVectorType(getValueType()); + assert WasmType.isVectorType(getType()); assert globalObjectValue != null; return (Vector128) globalObjectValue; } public Object loadAsReference() { - assert ValueType.isReferenceType(getValueType()); + assert WasmType.isReferenceType(getType()); assert globalObjectValue != null; return globalObjectValue; } public void storeInt(int value) { - assert ValueType.isNumberType(getValueType()); + assert WasmType.isNumberType(getType()); this.globalValue = value; } public void storeLong(long value) { - assert ValueType.isNumberType(getValueType()); + assert WasmType.isNumberType(getType()); this.globalValue = value; } public void storeVector128(Vector128 value) { - assert ValueType.isVectorType(getValueType()); + assert WasmType.isVectorType(getType()); this.globalObjectValue = value; } public void storeReference(Object value) { - assert ValueType.isReferenceType(getValueType()); + assert WasmType.isReferenceType(getType()); this.globalObjectValue = value; } @@ -172,13 +196,16 @@ Object readMember(String member) throws UnknownIdentifierException { throw UnknownIdentifierException.create(member); } assert VALUE_MEMBER.equals(member) : member; - return switch (getValueType()) { - case i32 -> loadAsInt(); - case i64 -> loadAsLong(); - case f32 -> Float.intBitsToFloat(loadAsInt()); - case f64 -> Double.longBitsToDouble(loadAsLong()); - case v128 -> loadAsVector128(); - case anyfunc, externref -> loadAsReference(); + return switch (getType()) { + case WasmType.I32_TYPE -> loadAsInt(); + case WasmType.I64_TYPE -> loadAsLong(); + case WasmType.F32_TYPE -> Float.intBitsToFloat(loadAsInt()); + case WasmType.F64_TYPE -> Double.longBitsToDouble(loadAsLong()); + case WasmType.V128_TYPE -> loadAsVector128(); + default -> { + assert WasmType.isReferenceType(getType()); + yield loadAsReference(); + } }; } @@ -205,28 +232,24 @@ void writeMember(String member, Object value, // Constant variables cannot be modified after linking. throw UnsupportedMessageException.create(); } - switch (getValueType()) { - case i32 -> storeInt(valueLibrary.asInt(value)); - case i64 -> storeLong(valueLibrary.asLong(value)); - case f32 -> storeInt(Float.floatToRawIntBits(valueLibrary.asFloat(value))); - case f64 -> storeLong(Double.doubleToRawLongBits(valueLibrary.asDouble(value))); - case v128 -> { - if (value instanceof Vector128 vector) { - storeVector128(vector); - } - throw UnsupportedMessageException.create(); - } - case anyfunc -> { - if (value == WasmConstant.NULL || value instanceof WasmFunctionInstance) { - storeReference(value); + switch (type) { + case WasmType.I32_TYPE -> storeInt(valueLibrary.asInt(value)); + case WasmType.I64_TYPE -> storeLong(valueLibrary.asLong(value)); + case WasmType.F32_TYPE -> storeInt(Float.floatToRawIntBits(valueLibrary.asFloat(value))); + case WasmType.F64_TYPE -> storeLong(Double.doubleToRawLongBits(valueLibrary.asDouble(value))); + case WasmType.V128_TYPE -> { + assert WasmType.isVectorType(type); + if (!getClosedValueType().matchesValue(value)) { + throw UnsupportedMessageException.create(); } - throw UnsupportedMessageException.create(); + storeVector128((Vector128) value); } - case externref -> { - if (value instanceof TruffleObject) { - storeReference(value); + default -> { + assert WasmType.isReferenceType(type); + if (!getClosedValueType().matchesValue(value)) { + throw UnsupportedMessageException.create(); } - throw UnsupportedMessageException.create(); + storeReference(value); } } } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java index 4ff55469a263..52d09248e835 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java @@ -587,7 +587,7 @@ public Object executeBodyFromOffset(WasmInstance instance, VirtualFrame frame, i // Validate that the target function type matches the expected type of the // indirect call. - if (!symtab.closedTypeAt(expectedFunctionTypeIndex).matches(new SymbolTable.ClosedReferenceType(false, function.closedType()))) { + if (!symtab.closedTypeAt(expectedFunctionTypeIndex).matchesType(new SymbolTable.ClosedReferenceType(false, function.closedType()))) { enterErrorBranch(); failFunctionTypeCheck(function, expectedFunctionTypeIndex); } From 26ea7195a1de11ce5f06be6f180134d2403a0362 Mon Sep 17 00:00:00 2001 From: Jirka Marsik Date: Wed, 8 Oct 2025 16:28:49 +0200 Subject: [PATCH 18/40] Use subtype matching for module imports --- .../src/org/graalvm/wasm/Linker.java | 12 ++++++++---- .../src/org/graalvm/wasm/SymbolTable.java | 4 ++++ 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java index 573d3d0e8f3d..2f730135cda3 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java @@ -40,7 +40,6 @@ */ package org.graalvm.wasm; -import static org.graalvm.wasm.Assert.assertIntEqual; import static org.graalvm.wasm.Assert.assertTrue; import static org.graalvm.wasm.Assert.assertUnsignedIntGreaterOrEqual; import static org.graalvm.wasm.Assert.assertUnsignedIntLess; @@ -309,9 +308,11 @@ void resolveGlobalImport(WasmStore store, WasmInstance instance, ImportDescripto assert instance.module().globalImported(globalIndex) && globalIndex == importDescriptor.targetIndex() : importDescriptor; WasmGlobal externalGlobal = lookupImportObject(instance, importDescriptor, imports, WasmGlobal.class); final int exportedValueType; + final SymbolTable.ClosedValueType exportedClosedValueType; final byte exportedMutability; if (externalGlobal != null) { exportedValueType = externalGlobal.getType(); + exportedClosedValueType = externalGlobal.getClosedValueType(); exportedMutability = externalGlobal.getMutability(); } else { final WasmInstance importedInstance = store.lookupModuleInstance(importedModuleName); @@ -328,11 +329,12 @@ void resolveGlobalImport(WasmStore store, WasmInstance instance, ImportDescripto } exportedValueType = importedInstance.symbolTable().globalValueType(exportedGlobalIndex); + exportedClosedValueType = importedInstance.symbolTable().globalClosedValueType(exportedGlobalIndex); exportedMutability = importedInstance.symbolTable().globalMutability(exportedGlobalIndex); externalGlobal = importedInstance.externalGlobal(exportedGlobalIndex); } - if (exportedValueType != valueType) { + if (instance.symbolTable().closedTypeAt(valueType).matchesType(exportedClosedValueType)) { throw WasmException.create(Failure.INCOMPATIBLE_IMPORT_TYPE, "Global variable '" + importedGlobalName + "' is imported into module '" + instance.name() + "' with the type " + WasmType.toString(valueType) + ", " + "'but it was exported in the module '" + importedModuleName + "' with the type " + WasmType.toString(exportedValueType) + "."); @@ -538,7 +540,9 @@ void resolveTagImport(WasmStore store, WasmInstance instance, ImportDescriptor i } importedTag = importedInstance.tag(exportedTagIndex); } - Assert.assertTrue(type.matchesType(importedTag.type()), Failure.INCOMPATIBLE_IMPORT_TYPE); + // matching for tag types does not work by subtyping, but requires equivalent types, + // A <= B and B <= A + Assert.assertTrue(type.matchesType(importedTag.type()) && importedTag.type().matchesType(type), Failure.INCOMPATIBLE_IMPORT_TYPE); instance.setTag(tagIndex, importedTag); }; resolutionDag.resolveLater(new ImportTagSym(instance.name(), importDescriptor, tagIndex), new Sym[]{new ExportTagSym(importedModuleName, importedTagName)}, resolveAction); @@ -849,7 +853,7 @@ void resolveTableImport(WasmStore store, WasmInstance instance, ImportDescriptor // MAX_TABLE_DECLARATION_SIZE, so this condition will pass. assertUnsignedIntLessOrEqual(declaredMinSize, importedTable.minSize(), Failure.INCOMPATIBLE_IMPORT_TYPE); assertUnsignedIntGreaterOrEqual(declaredMaxSize, importedTable.declaredMaxSize(), Failure.INCOMPATIBLE_IMPORT_TYPE); - assertIntEqual(elemType, importedTable.elemType(), Failure.INCOMPATIBLE_IMPORT_TYPE); + assertTrue(instance.symbolTable().closedTypeAt(elemType).matchesType(importedTable.closedValueType()), Failure.INCOMPATIBLE_IMPORT_TYPE); instance.setTableAddress(tableIndex, tableAddress); }; final ImportTableSym importTableSym = new ImportTableSym(instance.name(), importDescriptor); diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java index f3d923ff4c0b..06f306a029bf 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java @@ -1092,6 +1092,10 @@ public int globalValueType(int index) { return globalTypes[index]; } + public ClosedValueType globalClosedValueType(int index) { + return closedTypeAt(globalTypes[index]); + } + private byte globalFlags(int index) { return globalFlags[index]; } From 8c4c4f55f4cdeede13995bb17ebb2a9ff29104ab Mon Sep 17 00:00:00 2001 From: Jirka Marsik Date: Fri, 10 Oct 2025 16:04:40 +0200 Subject: [PATCH 19/40] Store pre-computed closed types in SymbolTable --- .../src/org/graalvm/wasm/BinaryParser.java | 1 + .../src/org/graalvm/wasm/Linker.java | 4 +-- .../src/org/graalvm/wasm/SymbolTable.java | 28 ++++++++++++++----- .../src/org/graalvm/wasm/WasmFunction.java | 4 ++- .../src/org/graalvm/wasm/WasmTable.java | 2 +- .../org/graalvm/wasm/globals/WasmGlobal.java | 2 +- 6 files changed, 29 insertions(+), 12 deletions(-) diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java index 8d4c39bb94d7..50246cbad1f8 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java @@ -3216,6 +3216,7 @@ private void readFunctionType() { for (int resultIdx = 0; resultIdx < resultCount; resultIdx++) { module.symbolTable().registerFunctionTypeResultType(funcTypeIdx, resultIdx, resultTypes[resultIdx]); } + module.symbolTable().finishFunctionType(funcTypeIdx); } protected int readValueType() { diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java index 2f730135cda3..261bae11c56d 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java @@ -334,7 +334,7 @@ void resolveGlobalImport(WasmStore store, WasmInstance instance, ImportDescripto externalGlobal = importedInstance.externalGlobal(exportedGlobalIndex); } - if (instance.symbolTable().closedTypeAt(valueType).matchesType(exportedClosedValueType)) { + if (instance.symbolTable().makeClosedType(valueType).matchesType(exportedClosedValueType)) { throw WasmException.create(Failure.INCOMPATIBLE_IMPORT_TYPE, "Global variable '" + importedGlobalName + "' is imported into module '" + instance.name() + "' with the type " + WasmType.toString(valueType) + ", " + "'but it was exported in the module '" + importedModuleName + "' with the type " + WasmType.toString(exportedValueType) + "."); @@ -853,7 +853,7 @@ void resolveTableImport(WasmStore store, WasmInstance instance, ImportDescriptor // MAX_TABLE_DECLARATION_SIZE, so this condition will pass. assertUnsignedIntLessOrEqual(declaredMinSize, importedTable.minSize(), Failure.INCOMPATIBLE_IMPORT_TYPE); assertUnsignedIntGreaterOrEqual(declaredMaxSize, importedTable.declaredMaxSize(), Failure.INCOMPATIBLE_IMPORT_TYPE); - assertTrue(instance.symbolTable().closedTypeAt(elemType).matchesType(importedTable.closedValueType()), Failure.INCOMPATIBLE_IMPORT_TYPE); + assertTrue(instance.symbolTable().makeClosedType(elemType).matchesType(importedTable.closedValueType()), Failure.INCOMPATIBLE_IMPORT_TYPE); instance.setTableAddress(tableIndex, tableAddress); }; final ImportTableSym importTableSym = new ImportTableSym(instance.name(), importDescriptor); diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java index 06f306a029bf..adedac876094 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java @@ -46,6 +46,7 @@ import static org.graalvm.wasm.WasmMath.maxUnsigned; import static org.graalvm.wasm.WasmMath.minUnsigned; +import java.lang.reflect.Array; import java.util.ArrayList; import java.util.List; @@ -402,6 +403,8 @@ public record TagInfo(byte attribute, int typeIndex) { */ @CompilationFinal(dimensions = 1) private int[] typeOffsets; + @CompilationFinal(dimensions = 1) private ClosedValueType[] closedTypes; + @CompilationFinal private int typeDataSize; @CompilationFinal private int typeCount; @@ -608,6 +611,7 @@ public record TagInfo(byte attribute, int typeIndex) { CompilerAsserts.neverPartOfCompilation(); this.typeData = new int[INITIAL_DATA_SIZE]; this.typeOffsets = new int[INITIAL_TYPE_SIZE]; + this.closedTypes = new ClosedValueType[INITIAL_TYPE_SIZE]; this.typeDataSize = 0; this.typeCount = 0; this.importedSymbols = new ArrayList<>(); @@ -670,8 +674,8 @@ private static int[] reallocate(int[] array, int currentSize, int newLength) { return newArray; } - private static WasmFunction[] reallocate(WasmFunction[] array, int currentSize, int newLength) { - WasmFunction[] newArray = new WasmFunction[newLength]; + private static T[] reallocate(T[] array, int currentSize, int newLength) { + T[] newArray = (T[]) Array.newInstance(array.getClass().getComponentType(), newLength); System.arraycopy(array, 0, newArray, 0, currentSize); return newArray; } @@ -702,6 +706,7 @@ private void ensureTypeCapacity(int index) { if (typeOffsets.length <= index) { int newLength = Math.max(Integer.highestOneBit(index) << 1, 2 * typeOffsets.length); typeOffsets = reallocate(typeOffsets, typeCount, newLength); + closedTypes = reallocate(closedTypes, typeCount, newLength); } } @@ -732,6 +737,7 @@ public int allocateFunctionType(int[] paramTypes, int[] resultTypes, boolean isM for (int i = 0; i < resultTypes.length; i++) { registerFunctionTypeResultType(typeIdx, i, resultTypes[i]); } + finishFunctionType(typeIdx); return typeIdx; } @@ -747,6 +753,10 @@ void registerFunctionTypeResultType(int funcTypeIdx, int resultIdx, int type) { typeData[idx] = type; } + void finishFunctionType(int funcTypeIdx) { + closedTypes[funcTypeIdx] = makeClosedType(funcTypeIdx); + } + private void ensureFunctionsCapacity(int index) { if (functions.length <= index) { int newLength = Math.max(Integer.highestOneBit(index) << 1, 2 * functions.length); @@ -860,16 +870,20 @@ int typeCount() { public ClosedFunctionType closedFunctionTypeAt(int typeIndex) { ClosedValueType[] paramTypes = new ClosedValueType[functionTypeParamCount(typeIndex)]; for (int i = 0; i < paramTypes.length; i++) { - paramTypes[i] = closedTypeAt(functionTypeParamTypeAt(typeIndex, i)); + paramTypes[i] = makeClosedType(functionTypeParamTypeAt(typeIndex, i)); } ClosedValueType[] resultTypes = new ClosedValueType[functionTypeResultCount(typeIndex)]; for (int i = 0; i < resultTypes.length; i++) { - resultTypes[i] = closedTypeAt(functionTypeResultTypeAt(typeIndex, i)); + resultTypes[i] = makeClosedType(functionTypeResultTypeAt(typeIndex, i)); } return new ClosedFunctionType(paramTypes, resultTypes); } - public ClosedValueType closedTypeAt(int type) { + public ClosedValueType closedTypeAt(int typeIndex) { + return closedTypes[typeIndex]; + } + + public ClosedValueType makeClosedType(int type) { return switch (type) { case WasmType.I32_TYPE -> NumberType.I32; case WasmType.I64_TYPE -> NumberType.I64; @@ -896,7 +910,7 @@ yield switch (WasmType.getAbstractHeapType(type)) { } public boolean matches(int expectedType, int actualType) { - return closedTypeAt(expectedType).matchesType(closedTypeAt(actualType)); + return makeClosedType(expectedType).matchesType(makeClosedType(actualType)); } public void importSymbol(ImportDescriptor descriptor) { @@ -1093,7 +1107,7 @@ public int globalValueType(int index) { } public ClosedValueType globalClosedValueType(int index) { - return closedTypeAt(globalTypes[index]); + return makeClosedType(globalTypes[index]); } private byte globalFlags(int index) { diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmFunction.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmFunction.java index d2af6a9c437b..7f286753cc37 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmFunction.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmFunction.java @@ -49,6 +49,7 @@ public final class WasmFunction { private final int index; private final ImportDescriptor importDescriptor; private final int typeIndex; + private final SymbolTable.ClosedFunctionType closedFunctionType; @CompilationFinal private String debugName; @CompilationFinal private CallTarget callTarget; /** Interop call adapter for argument and return value validation and conversion. */ @@ -62,6 +63,7 @@ public WasmFunction(SymbolTable symbolTable, int index, int typeIndex, ImportDes this.index = index; this.importDescriptor = importDescriptor; this.typeIndex = typeIndex; + this.closedFunctionType = symbolTable.closedFunctionTypeAt(typeIndex); } public String moduleName() { @@ -145,7 +147,7 @@ public int typeIndex() { } public SymbolTable.ClosedFunctionType closedType() { - return symbolTable.closedFunctionTypeAt(typeIndex()); + return closedFunctionType; } public int index() { diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmTable.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmTable.java index 6b66dbc6d16e..afde7747e9fe 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmTable.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmTable.java @@ -171,7 +171,7 @@ public int elemType() { } public SymbolTable.ClosedValueType closedValueType() { - return symbolTable.closedTypeAt(elemType); + return symbolTable.makeClosedType(elemType); } /** diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/globals/WasmGlobal.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/globals/WasmGlobal.java index cb27b08d2f06..68974c2524f7 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/globals/WasmGlobal.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/globals/WasmGlobal.java @@ -111,7 +111,7 @@ public int getType() { public SymbolTable.ClosedValueType getClosedValueType() { if (symbolTable != null) { - return symbolTable.closedTypeAt(getType()); + return symbolTable.makeClosedType(getType()); } else { // Global was created by WebAssembly#global_alloc return switch (ValueType.fromValue(getType())) { From d58b3ce9af7dae0171487fd5f619b21bad532efc Mon Sep 17 00:00:00 2001 From: Jirka Marsik Date: Fri, 17 Oct 2025 15:49:10 +0200 Subject: [PATCH 20/40] Divide matchesType into isSupertypeOf and isSubtypeOf This makes sure we always dispatch on a compilation final constant. --- .../src/org/graalvm/wasm/Linker.java | 10 +- .../src/org/graalvm/wasm/SymbolTable.java | 121 +++++++++++++++--- .../wasm/api/InteropCallAdapterNode.java | 2 +- .../src/org/graalvm/wasm/api/ValueType.java | 1 + .../graalvm/wasm/nodes/WasmFunctionNode.java | 2 +- 5 files changed, 113 insertions(+), 23 deletions(-) diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java index 261bae11c56d..8e7d3f47d97b 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java @@ -334,7 +334,7 @@ void resolveGlobalImport(WasmStore store, WasmInstance instance, ImportDescripto externalGlobal = importedInstance.externalGlobal(exportedGlobalIndex); } - if (instance.symbolTable().makeClosedType(valueType).matchesType(exportedClosedValueType)) { + if (instance.symbolTable().makeClosedType(valueType).isSupertypeOf(exportedClosedValueType)) { throw WasmException.create(Failure.INCOMPATIBLE_IMPORT_TYPE, "Global variable '" + importedGlobalName + "' is imported into module '" + instance.name() + "' with the type " + WasmType.toString(valueType) + ", " + "'but it was exported in the module '" + importedModuleName + "' with the type " + WasmType.toString(exportedValueType) + "."); @@ -403,7 +403,7 @@ void resolveFunctionImport(WasmStore store, WasmInstance instance, WasmFunction Object externalFunctionInstance = lookupImportObject(instance, importDescriptor, imports, Object.class); if (externalFunctionInstance != null) { if (externalFunctionInstance instanceof WasmFunctionInstance functionInstance) { - if (!function.closedType().matchesType(functionInstance.function().closedType())) { + if (!function.closedType().isSupertypeOf(functionInstance.function().closedType())) { throw WasmException.create(Failure.INCOMPATIBLE_IMPORT_TYPE); } instance.setTarget(function.index(), functionInstance.target()); @@ -436,7 +436,7 @@ void resolveFunctionImport(WasmStore store, WasmInstance instance, WasmFunction throw WasmException.create(Failure.UNKNOWN_IMPORT, "The imported function '" + function.importedFunctionName() + "', referenced in the module '" + instance.name() + "', does not exist in the imported module '" + function.importedModuleName() + "'."); } - if (!function.closedType().matchesType(importedFunction.closedType())) { + if (!function.closedType().isSupertypeOf(importedFunction.closedType())) { throw WasmException.create(Failure.INCOMPATIBLE_IMPORT_TYPE); } final CallTarget target = importedInstance.target(importedFunction.index()); @@ -542,7 +542,7 @@ void resolveTagImport(WasmStore store, WasmInstance instance, ImportDescriptor i } // matching for tag types does not work by subtyping, but requires equivalent types, // A <= B and B <= A - Assert.assertTrue(type.matchesType(importedTag.type()) && importedTag.type().matchesType(type), Failure.INCOMPATIBLE_IMPORT_TYPE); + Assert.assertTrue(type.isSupertypeOf(importedTag.type()) && importedTag.type().isSupertypeOf(type), Failure.INCOMPATIBLE_IMPORT_TYPE); instance.setTag(tagIndex, importedTag); }; resolutionDag.resolveLater(new ImportTagSym(instance.name(), importDescriptor, tagIndex), new Sym[]{new ExportTagSym(importedModuleName, importedTagName)}, resolveAction); @@ -853,7 +853,7 @@ void resolveTableImport(WasmStore store, WasmInstance instance, ImportDescriptor // MAX_TABLE_DECLARATION_SIZE, so this condition will pass. assertUnsignedIntLessOrEqual(declaredMinSize, importedTable.minSize(), Failure.INCOMPATIBLE_IMPORT_TYPE); assertUnsignedIntGreaterOrEqual(declaredMaxSize, importedTable.declaredMaxSize(), Failure.INCOMPATIBLE_IMPORT_TYPE); - assertTrue(instance.symbolTable().makeClosedType(elemType).matchesType(importedTable.closedValueType()), Failure.INCOMPATIBLE_IMPORT_TYPE); + assertTrue(instance.symbolTable().makeClosedType(elemType).isSupertypeOf(importedTable.closedValueType()), Failure.INCOMPATIBLE_IMPORT_TYPE); instance.setTableAddress(tableIndex, tableAddress); }; final ImportTableSym importTableSym = new ImportTableSym(instance.name(), importDescriptor); diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java index adedac876094..b82be4b6af4a 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java @@ -92,10 +92,13 @@ public enum Kind { Number, Vector, Reference, - Bottom + Bottom, + Top } - public abstract boolean matchesType(ClosedValueType valueSubType); + public abstract boolean isSupertypeOf(ClosedValueType valueSubType); + + public abstract boolean isSubtypeOf(ClosedValueType valueSuperType); public abstract boolean matchesValue(Object value); @@ -109,7 +112,9 @@ public enum Kind { Function } - public abstract boolean matchesType(ClosedHeapType heapSubType); + public abstract boolean isSupertypeOf(ClosedHeapType heapSubType); + + public abstract boolean isSubtypeOf(ClosedHeapType heapSuperType); public abstract boolean matchesValue(Object value); @@ -133,10 +138,15 @@ public int value() { } @Override - public boolean matchesType(ClosedValueType valueSubType) { + public boolean isSupertypeOf(ClosedValueType valueSubType) { return valueSubType == BottomType.BOTTOM || valueSubType == this; } + @Override + public boolean isSubtypeOf(ClosedValueType valueSuperType) { + return valueSuperType == TopType.TOP || valueSuperType == this; + } + @Override public boolean matchesValue(Object val) { return switch (value()) { @@ -168,8 +178,13 @@ public int value() { } @Override - public boolean matchesType(ClosedValueType valueSubType) { - return valueSubType == BottomType.BOTTOM || valueSubType == this; + public boolean isSupertypeOf(ClosedValueType valueSubType) { + return valueSubType == BottomType.BOTTOM || valueSubType == V128; + } + + @Override + public boolean isSubtypeOf(ClosedValueType valueSuperType) { + return valueSuperType == TopType.TOP || valueSuperType == V128; } @Override @@ -208,9 +223,15 @@ public ClosedHeapType heapType() { } @Override - public boolean matchesType(ClosedValueType valueSubType) { + public boolean isSupertypeOf(ClosedValueType valueSubType) { return valueSubType == BottomType.BOTTOM || valueSubType instanceof ClosedReferenceType referenceSubType && (!referenceSubType.nullable || this.nullable) && - this.closedHeapType.matchesType(referenceSubType.closedHeapType); + this.closedHeapType.isSupertypeOf(referenceSubType.closedHeapType); + } + + @Override + public boolean isSubtypeOf(ClosedValueType valueSuperType) { + return valueSuperType == TopType.TOP || valueSuperType instanceof ClosedReferenceType referencedSuperType && (!this.nullable || referencedSuperType.nullable) && + this.closedHeapType.isSubtypeOf(referencedSuperType.closedHeapType); } @Override @@ -240,7 +261,7 @@ public int value() { } @Override - public boolean matchesType(ClosedHeapType heapSubType) { + public boolean isSupertypeOf(ClosedHeapType heapSubType) { return switch (this.value) { case WasmType.FUNC_HEAPTYPE -> heapSubType == FUNC || heapSubType instanceof ClosedFunctionType; case WasmType.EXTERN_HEAPTYPE -> heapSubType == EXTERN; @@ -249,6 +270,11 @@ public boolean matchesType(ClosedHeapType heapSubType) { }; } + @Override + public boolean isSubtypeOf(ClosedHeapType heapSuperType) { + return heapSuperType == this; + } + @Override public boolean matchesValue(Object val) { return switch (this.value) { @@ -283,7 +309,7 @@ public ClosedValueType[] resultTypes() { } @Override - public boolean matchesType(ClosedHeapType heapSubType) { + public boolean isSupertypeOf(ClosedHeapType heapSubType) { if (!(heapSubType instanceof ClosedFunctionType functionSubType)) { return false; } @@ -291,7 +317,8 @@ public boolean matchesType(ClosedHeapType heapSubType) { return false; } for (int i = 0; i < this.paramTypes.length; i++) { - if (!functionSubType.paramTypes[i].matchesType(this.paramTypes[i])) { + CompilerAsserts.partialEvaluationConstant(this.paramTypes[i]); + if (!this.paramTypes[i].isSubtypeOf(functionSubType.paramTypes[i])) { return false; } } @@ -299,7 +326,37 @@ public boolean matchesType(ClosedHeapType heapSubType) { return false; } for (int i = 0; i < this.resultTypes.length; i++) { - if (!this.resultTypes[i].matchesType(functionSubType.resultTypes[i])) { + CompilerAsserts.partialEvaluationConstant(this.resultTypes[i]); + if (!this.resultTypes[i].isSupertypeOf(functionSubType.resultTypes[i])) { + return false; + } + } + return true; + } + + @Override + public boolean isSubtypeOf(ClosedHeapType heapSuperType) { + if (heapSuperType == AbstractHeapType.FUNC) { + return true; + } + if (!(heapSuperType instanceof ClosedFunctionType functionSuperType)) { + return false; + } + if (this.paramTypes.length != functionSuperType.paramTypes.length) { + return false; + } + for (int i = 0; i < this.paramTypes.length; i++) { + CompilerAsserts.partialEvaluationConstant(this.paramTypes[i]); + if (!this.paramTypes[i].isSupertypeOf(functionSuperType.paramTypes[i])) { + return false; + } + } + if (this.resultTypes.length != functionSuperType.resultTypes.length) { + return false; + } + for (int i = 0; i < this.resultTypes.length; i++) { + CompilerAsserts.partialEvaluationConstant(this.resultTypes[i]); + if (!this.resultTypes[i].isSubtypeOf(functionSuperType.resultTypes[i])) { return false; } } @@ -308,7 +365,7 @@ public boolean matchesType(ClosedHeapType heapSubType) { @Override public boolean matchesValue(Object value) { - return value instanceof WasmFunctionInstance instance && matchesType(instance.function().closedType()); + return value instanceof WasmFunctionInstance instance && isSupertypeOf(instance.function().closedType()); } @Override @@ -324,8 +381,13 @@ private BottomType() { } @Override - public boolean matchesType(ClosedValueType valueSubType) { - return valueSubType instanceof BottomType; + public boolean isSupertypeOf(ClosedValueType valueSubType) { + return valueSubType == BOTTOM; + } + + @Override + public boolean isSubtypeOf(ClosedValueType valueSuperType) { + return true; } @Override @@ -339,6 +401,33 @@ public Kind kind() { } } + public static final class TopType extends ClosedValueType { + public static final TopType TOP = new TopType(); + + private TopType() { + } + + @Override + public boolean isSupertypeOf(ClosedValueType valueSubType) { + return true; + } + + @Override + public boolean isSubtypeOf(ClosedValueType valueSuperType) { + return valueSuperType == TOP; + } + + @Override + public boolean matchesValue(Object value) { + return false; + } + + @Override + public Kind kind() { + return Kind.Top; + } + } + /** * @param initialSize Lower bound on table size. * @param maximumSize Upper bound on table size. @@ -910,7 +999,7 @@ yield switch (WasmType.getAbstractHeapType(type)) { } public boolean matches(int expectedType, int actualType) { - return makeClosedType(expectedType).matchesType(makeClosedType(actualType)); + return makeClosedType(expectedType).isSupertypeOf(makeClosedType(actualType)); } public void importSymbol(ImportDescriptor descriptor) { diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/InteropCallAdapterNode.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/InteropCallAdapterNode.java index f791a52c852f..2719b169c94b 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/InteropCallAdapterNode.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/InteropCallAdapterNode.java @@ -183,7 +183,7 @@ yield switch (numberType.value()) { objectMultiValueStack[i] = null; yield obj; } - case Bottom -> throw CompilerDirectives.shouldNotReachHere(); + case Bottom, Top -> throw CompilerDirectives.shouldNotReachHere(); }; } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ValueType.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ValueType.java index 973ab62a8eab..90ab8037f4ea 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ValueType.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ValueType.java @@ -132,6 +132,7 @@ yield switch (abstractHeapType.value()) { }; } case Bottom -> throw WasmException.create(Failure.UNSPECIFIED_INTERNAL, null, "Unknown value type: bottom"); + case Top -> throw WasmException.create(Failure.UNSPECIFIED_INTERNAL, null, "Unknown value type: top"); }; } } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java index 52d09248e835..a4a1c468714a 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java @@ -587,7 +587,7 @@ public Object executeBodyFromOffset(WasmInstance instance, VirtualFrame frame, i // Validate that the target function type matches the expected type of the // indirect call. - if (!symtab.closedTypeAt(expectedFunctionTypeIndex).matchesType(new SymbolTable.ClosedReferenceType(false, function.closedType()))) { + if (!symtab.closedTypeAt(expectedFunctionTypeIndex).isSupertypeOf(new SymbolTable.ClosedReferenceType(false, function.closedType()))) { enterErrorBranch(); failFunctionTypeCheck(function, expectedFunctionTypeIndex); } From c2ba9276a5eb3423e40cc9e0c8de8c10e90f7437 Mon Sep 17 00:00:00 2001 From: Jirka Marsik Date: Fri, 17 Oct 2025 21:18:36 +0200 Subject: [PATCH 21/40] Add missing ExplodeLoop to subtype checking --- .../src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java index b82be4b6af4a..072e3ddaf193 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java @@ -50,6 +50,7 @@ import java.util.ArrayList; import java.util.List; +import com.oracle.truffle.api.nodes.ExplodeLoop; import org.graalvm.collections.EconomicMap; import org.graalvm.collections.EconomicSet; import org.graalvm.collections.MapCursor; @@ -309,6 +310,7 @@ public ClosedValueType[] resultTypes() { } @Override + @ExplodeLoop(kind = ExplodeLoop.LoopExplosionKind.FULL_UNROLL) public boolean isSupertypeOf(ClosedHeapType heapSubType) { if (!(heapSubType instanceof ClosedFunctionType functionSubType)) { return false; @@ -335,6 +337,7 @@ public boolean isSupertypeOf(ClosedHeapType heapSubType) { } @Override + @ExplodeLoop(kind = ExplodeLoop.LoopExplosionKind.FULL_UNROLL) public boolean isSubtypeOf(ClosedHeapType heapSuperType) { if (heapSuperType == AbstractHeapType.FUNC) { return true; From 1792ee5b4bb46d69de03a20c7719fafc140809bd Mon Sep 17 00:00:00 2001 From: Jirka Marsik Date: Thu, 9 Oct 2025 10:15:12 +0200 Subject: [PATCH 22/40] Do not pass in non-wasm functions as funcref values The funcref type only admits wasm functions. External functions must go through the import object, where they are wrapped in ExecuteHostFunctionNode. --- .../org/graalvm/wasm/test/WasmJsApiSuite.java | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/WasmJsApiSuite.java b/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/WasmJsApiSuite.java index 30fe023253ec..6959417dfda9 100644 --- a/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/WasmJsApiSuite.java +++ b/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/WasmJsApiSuite.java @@ -1411,7 +1411,14 @@ public void testFuncTypeMultiValue() throws IOException, InterruptedException { @Test public void testMultiValueReferencePassThrough() throws IOException, InterruptedException { - final byte[] source = compileWat("data", """ + final byte[] source1 = compileWat("data", """ + (module + (type (func (result i32))) + (func (export "func") (type 0) + i32.const 42 + )) + """); + final byte[] source2 = compileWat("data", """ (module (type (func (result funcref externref))) (import "m" "f" (func (type 0))) @@ -1421,7 +1428,8 @@ public void testMultiValueReferencePassThrough() throws IOException, Interrupted """); runTest(context -> { final WebAssembly wasm = new WebAssembly(context); - final var func = new Executable((args) -> 0); + final WasmInstance instance1 = moduleInstantiate(wasm, source1, null); + final Object func = WebAssembly.instanceExport(instance1, "func"); final var f = new Executable((args) -> { final Object[] result = new Object[2]; result[0] = func; @@ -1429,8 +1437,8 @@ public void testMultiValueReferencePassThrough() throws IOException, Interrupted return InteropArray.create(result); }); final Dictionary importObject = Dictionary.create(new Object[]{"m", Dictionary.create(new Object[]{"f", f})}); - final WasmInstance instance = moduleInstantiate(wasm, source, importObject); - final Object main = WebAssembly.instanceExport(instance, "main"); + final WasmInstance instance2 = moduleInstantiate(wasm, source2, importObject); + final Object main = WebAssembly.instanceExport(instance2, "main"); final InteropLibrary lib = InteropLibrary.getUncached(); try { Object result = lib.execute(main); From 130b9660a53d2b38e390ee8f8d79f70d154162d5 Mon Sep 17 00:00:00 2001 From: Jirka Marsik Date: Thu, 9 Oct 2025 12:40:06 +0200 Subject: [PATCH 23/40] Drop unused branches in set_callback functions --- .../src/org/graalvm/wasm/api/WebAssembly.java | 30 ------------------- 1 file changed, 30 deletions(-) diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/WebAssembly.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/WebAssembly.java index cae2a1166adb..3ed11375b77f 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/WebAssembly.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/WebAssembly.java @@ -652,16 +652,6 @@ public static long memGrow(WasmMemory memory, int delta) { private static Object memSetGrowCallback(Object[] args) { checkArgumentCount(args, 1); InteropLibrary lib = InteropLibrary.getUncached(); - if (args.length > 1) { - // TODO: drop this branch after JS adopts the single-argument version - if (!(args[0] instanceof WasmMemory)) { - throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "First argument must be executable"); - } - if (!lib.isExecutable(args[1])) { - throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Second argument must be executable"); - } - return memSetGrowCallback(args[1]); - } if (!lib.isExecutable(args[0])) { throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Argument must be executable"); } @@ -690,16 +680,6 @@ public static void invokeMemGrowCallback(WasmMemory memory) { private static Object memSetNotifyCallback(Object[] args) { checkArgumentCount(args, 1); InteropLibrary lib = InteropLibrary.getUncached(); - if (args.length > 1) { - // TODO: drop this branch after JS adopts the single-argument version - if (!(args[0] instanceof WasmMemory)) { - throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "First argument must be executable"); - } - if (!lib.isExecutable(args[1])) { - throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Second argument must be executable"); - } - return memSetNotifyCallback(args[1]); - } if (!lib.isExecutable(args[0])) { throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Argument must be executable"); } @@ -729,16 +709,6 @@ public static int invokeMemNotifyCallback(Node node, WasmMemory memory, long add private static Object memSetWaitCallback(Object[] args) { checkArgumentCount(args, 1); InteropLibrary lib = InteropLibrary.getUncached(); - if (args.length > 1) { - // TODO: drop this branch after JS adopts the single-argument version - if (!(args[0] instanceof WasmMemory)) { - throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "First argument must be executable"); - } - if (!lib.isExecutable(args[1])) { - throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Second argument must be executable"); - } - return memSetWaitCallback(args[1]); - } if (!lib.isExecutable(args[0])) { throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Argument must be executable"); } From 1006ca5af561502aa503b5c6f581fcdafa365e7e Mon Sep 17 00:00:00 2001 From: Jirka Marsik Date: Thu, 9 Oct 2025 18:22:25 +0200 Subject: [PATCH 24/40] Use subtype matching when checking the type of empty else branch --- .../wasm/parser/validation/BlockFrame.java | 11 ++++++----- .../wasm/parser/validation/ControlFrame.java | 11 ++++++++++- .../graalvm/wasm/parser/validation/IfFrame.java | 16 ++++++++++++---- .../wasm/parser/validation/LoopFrame.java | 2 +- .../wasm/parser/validation/ParserState.java | 2 +- 5 files changed, 30 insertions(+), 12 deletions(-) diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/BlockFrame.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/BlockFrame.java index fdfba38a1d2c..046958abc330 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/BlockFrame.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/BlockFrame.java @@ -44,6 +44,7 @@ import java.util.ArrayList; import java.util.BitSet; +import org.graalvm.wasm.SymbolTable; import org.graalvm.wasm.WasmType; import org.graalvm.wasm.collection.IntArrayList; import org.graalvm.wasm.exception.Failure; @@ -57,24 +58,24 @@ class BlockFrame extends ControlFrame { private final IntArrayList branches; private final ArrayList exceptionHandlers; - private BlockFrame(int[] paramTypes, int[] resultTypes, int initialStackSize, BitSet initializedLocals) { - super(paramTypes, resultTypes, initialStackSize, initializedLocals); + private BlockFrame(int[] paramTypes, int[] resultTypes, SymbolTable symbolTable, int initialStackSize, BitSet initializedLocals) { + super(paramTypes, resultTypes, symbolTable, initialStackSize, initializedLocals); branches = new IntArrayList(); exceptionHandlers = new ArrayList<>(); } BlockFrame(int[] paramTypes, int[] resultTypes, int initialStackSize, ControlFrame parentFrame) { - this(paramTypes, resultTypes, initialStackSize, (BitSet) parentFrame.initializedLocals.clone()); + this(paramTypes, resultTypes, parentFrame.getSymbolTable(), initialStackSize, (BitSet) parentFrame.initializedLocals.clone()); } - static BlockFrame createFunctionFrame(int[] paramTypes, int[] resultTypes, int[] locals) { + static BlockFrame createFunctionFrame(int[] paramTypes, int[] resultTypes, int[] locals, SymbolTable symbolTable) { BitSet initializedLocals = new BitSet(locals.length); for (int localIndex = 0; localIndex < locals.length; localIndex++) { if (localIndex < paramTypes.length || WasmType.hasDefaultValue(locals[localIndex])) { initializedLocals.set(localIndex); } } - return new BlockFrame(paramTypes, resultTypes, 0, initializedLocals); + return new BlockFrame(paramTypes, resultTypes, symbolTable, 0, initializedLocals); } @Override diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ControlFrame.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ControlFrame.java index a8a211627e5d..b522b2003ad2 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ControlFrame.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ControlFrame.java @@ -41,6 +41,7 @@ package org.graalvm.wasm.parser.validation; +import org.graalvm.wasm.SymbolTable; import org.graalvm.wasm.WasmType; import org.graalvm.wasm.parser.bytecode.RuntimeBytecodeGen; @@ -52,6 +53,7 @@ public abstract class ControlFrame { private final int[] paramTypes; private final int[] resultTypes; + private final SymbolTable symbolTable; private final int initialStackSize; private boolean unreachable; @@ -61,11 +63,14 @@ public abstract class ControlFrame { /** * @param paramTypes The parameter value types of the block structure. * @param resultTypes The result value types of the block structure. + * @param symbolTable Necessary to look up the definitions of types in {@code paramTypes} and {@code resultTypes} * @param initialStackSize The size of the value stack when entering this block structure. + * @param initializedLocals The set of locals which are already initialized at the start of this function */ - ControlFrame(int[] paramTypes, int[] resultTypes, int initialStackSize, BitSet initializedLocals) { + ControlFrame(int[] paramTypes, int[] resultTypes, SymbolTable symbolTable, int initialStackSize, BitSet initializedLocals) { this.paramTypes = paramTypes; this.resultTypes = resultTypes; + this.symbolTable = symbolTable; this.initialStackSize = initialStackSize; this.unreachable = false; commonResultType = WasmType.getCommonValueType(resultTypes); @@ -84,6 +89,10 @@ protected int resultTypeLength() { return resultTypes.length; } + protected SymbolTable getSymbolTable() { + return symbolTable; + } + /** * @return The union of all result types. */ diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/IfFrame.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/IfFrame.java index 1dc4ec731a13..b061ba167999 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/IfFrame.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/IfFrame.java @@ -42,7 +42,6 @@ package org.graalvm.wasm.parser.validation; import java.util.ArrayList; -import java.util.Arrays; import java.util.BitSet; import org.graalvm.wasm.collection.IntArrayList; @@ -62,7 +61,7 @@ class IfFrame extends ControlFrame { private boolean elseBranch; IfFrame(int[] paramTypes, int[] resultTypes, int initialStackSize, ControlFrame parentFrame, int falseJumpLocation) { - super(paramTypes, resultTypes, initialStackSize, (BitSet) parentFrame.initializedLocals.clone()); + super(paramTypes, resultTypes, parentFrame.getSymbolTable(), initialStackSize, (BitSet) parentFrame.initializedLocals.clone()); this.branchTargets = new IntArrayList(); this.exceptionHandlers = new ArrayList<>(); this.parentFrame = parentFrame; @@ -89,8 +88,17 @@ void enterElse(ParserState state, RuntimeBytecodeGen bytecode) { @Override void exit(RuntimeBytecodeGen bytecode) { - if (!elseBranch && !Arrays.equals(paramTypes(), resultTypes())) { - throw WasmException.create(Failure.TYPE_MISMATCH, "Expected else branch. If with incompatible param and result types requires else branch."); + if (!elseBranch) { + if (resultTypes().length != paramTypes().length) { + throw WasmException.create(Failure.TYPE_MISMATCH, "Expected else branch. If with incompatible param and result types requires else branch."); + } + if (!isUnreachable()) { + for (int i = 0; i < resultTypes().length; i++) { + if (!getSymbolTable().matches(resultTypes()[i], paramTypes()[i])) { + throw WasmException.create(Failure.TYPE_MISMATCH, "Expected else branch. If with incompatible param and result types requires else branch."); + } + } + } } if (branchTargets.size() == 0 && exceptionHandlers.isEmpty()) { bytecode.patchLocation(falseJumpLocation, bytecode.location()); diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/LoopFrame.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/LoopFrame.java index 3aae0d9cfb3c..0575e3560070 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/LoopFrame.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/LoopFrame.java @@ -54,7 +54,7 @@ class LoopFrame extends ControlFrame { private final int labelLocation; LoopFrame(int[] paramTypes, int[] resultTypes, int initialStackSize, ControlFrame parentFrame, int labelLocation) { - super(paramTypes, resultTypes, initialStackSize, (BitSet) parentFrame.initializedLocals.clone()); + super(paramTypes, resultTypes, parentFrame.getSymbolTable(), initialStackSize, (BitSet) parentFrame.initializedLocals.clone()); this.labelLocation = labelLocation; } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ParserState.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ParserState.java index 2c8caa2e3608..d659e87e4187 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ParserState.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ParserState.java @@ -266,7 +266,7 @@ private void unwindStack(int size) { } public void enterFunction(int[] paramTypes, int[] resultTypes, int[] locals) { - ControlFrame frame = BlockFrame.createFunctionFrame(paramTypes, resultTypes, locals); + ControlFrame frame = BlockFrame.createFunctionFrame(paramTypes, resultTypes, locals, symbolTable); controlStack.push(frame); } From 1a25822a6036ab12cf4c8498080edc7390881c94 Mon Sep 17 00:00:00 2001 From: Jirka Marsik Date: Thu, 9 Oct 2025 21:58:28 +0200 Subject: [PATCH 25/40] Revert to exnref in ValueType, forbid exnref in JS global helpers According to the wasm JS API spec, exnref is not a valid ValueType. However, we will need to enforce that the types v128 and exnref never cross the wasm and JS boundary and we will need to do so in JS (we allow v128 values crossing from wasm to the spec test harness and other embedders). Since we use ValueType as a simple form of type reflection, we will need to have exnref so that GraalJS can detect exnref in wasm type signatures. --- .../src/org/graalvm/wasm/SymbolTable.java | 26 ++++++++ .../wasm/api/ExecuteHostFunctionNode.java | 1 + .../src/org/graalvm/wasm/api/ValueType.java | 12 ++-- .../src/org/graalvm/wasm/api/WebAssembly.java | 66 +++++++------------ .../org/graalvm/wasm/exception/Failure.java | 2 +- .../wasm/exception/WasmJsApiException.java | 3 + .../org/graalvm/wasm/globals/WasmGlobal.java | 2 +- .../graalvm/wasm/nodes/WasmFunctionNode.java | 2 +- 8 files changed, 62 insertions(+), 52 deletions(-) diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java index 072e3ddaf193..9c8669e9b50a 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java @@ -975,6 +975,32 @@ public ClosedValueType closedTypeAt(int typeIndex) { return closedTypes[typeIndex]; } + public static ClosedValueType closedTypeOfPredefined(int type) { + return switch (type) { + case WasmType.I32_TYPE -> NumberType.I32; + case WasmType.I64_TYPE -> NumberType.I64; + case WasmType.F32_TYPE -> NumberType.F32; + case WasmType.F64_TYPE -> NumberType.F64; + case WasmType.V128_TYPE -> VectorType.V128; + default -> { + if (WasmType.isBottomType(type)) { + yield BottomType.BOTTOM; + } + assert WasmType.isReferenceType(type); + boolean nullable = WasmType.isNullable(type); + yield switch (WasmType.getAbstractHeapType(type)) { + case WasmType.FUNC_HEAPTYPE -> nullable ? ClosedReferenceType.FUNCREF : ClosedReferenceType.NONNULL_FUNCREF; + case WasmType.EXTERN_HEAPTYPE -> nullable ? ClosedReferenceType.EXTERNREF : ClosedReferenceType.NONNULL_EXTERNREF; + case WasmType.EXN_HEAPTYPE -> nullable ? ClosedReferenceType.EXNREF : ClosedReferenceType.NONNULL_EXNREF; + default -> { + assert WasmType.isConcreteReferenceType(type); + throw new IllegalArgumentException(); + } + }; + } + }; + } + public ClosedValueType makeClosedType(int type) { return switch (type) { case WasmType.I32_TYPE -> NumberType.I32; diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ExecuteHostFunctionNode.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ExecuteHostFunctionNode.java index 363d202492ef..67e40d867ad3 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ExecuteHostFunctionNode.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ExecuteHostFunctionNode.java @@ -177,6 +177,7 @@ private void pushMultiValueResult(Object result, int resultCount) { default -> { assert WasmType.isVectorType(resultType) || WasmType.isReferenceType(resultType); if (!closedResultType.matchesValue(result)) { + errorBranch.enter(); throw WasmException.create(Failure.INVALID_TYPE_IN_MULTI_VALUE); } objectMultiValueStack[i] = value; diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ValueType.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ValueType.java index 90ab8037f4ea..e1f3a499e3d4 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ValueType.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ValueType.java @@ -42,10 +42,10 @@ import org.graalvm.wasm.SymbolTable; import org.graalvm.wasm.WasmType; -import org.graalvm.wasm.exception.Failure; -import org.graalvm.wasm.exception.WasmException; import com.oracle.truffle.api.CompilerAsserts; +import org.graalvm.wasm.exception.Failure; +import org.graalvm.wasm.exception.WasmException; public enum ValueType { i32(WasmType.I32_TYPE), @@ -54,7 +54,8 @@ public enum ValueType { f64(WasmType.F64_TYPE), v128(WasmType.V128_TYPE), anyfunc(WasmType.FUNCREF_TYPE), - externref(WasmType.EXTERNREF_TYPE); + externref(WasmType.EXTERNREF_TYPE), + exnref(WasmType.EXNREF_TYPE); private final int value; @@ -75,7 +76,7 @@ public static ValueType fromValue(int value) { yield switch (WasmType.getAbstractHeapType(value)) { case WasmType.FUNC_HEAPTYPE -> anyfunc; case WasmType.EXTERN_HEAPTYPE -> externref; - case WasmType.EXN_HEAPTYPE -> throw WasmException.create(Failure.UNSPECIFIED_INTERNAL, null, "Unknown value type: 0x" + Integer.toHexString(value)); + case WasmType.EXN_HEAPTYPE -> exnref; default -> { assert WasmType.isConcreteReferenceType(value); yield anyfunc; @@ -98,7 +99,7 @@ public static boolean isVectorType(ValueType valueType) { } public static boolean isReferenceType(ValueType valueType) { - return valueType == anyfunc || valueType == externref; + return valueType == anyfunc || valueType == externref || valueType == exnref; } public SymbolTable.ClosedValueType asClosedValueType() { @@ -110,6 +111,7 @@ public SymbolTable.ClosedValueType asClosedValueType() { case v128 -> SymbolTable.VectorType.V128; case anyfunc -> SymbolTable.ClosedReferenceType.FUNCREF; case externref -> SymbolTable.ClosedReferenceType.EXTERNREF; + case exnref -> SymbolTable.ClosedReferenceType.EXNREF; }; } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/WebAssembly.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/WebAssembly.java index 3ed11375b77f..9bca8ee57e2f 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/WebAssembly.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/WebAssembly.java @@ -53,6 +53,7 @@ import org.graalvm.polyglot.io.ByteSequence; import org.graalvm.wasm.EmbedderDataHolder; import org.graalvm.wasm.ImportDescriptor; +import org.graalvm.wasm.SymbolTable; import org.graalvm.wasm.WasmConstant; import org.graalvm.wasm.WasmContext; import org.graalvm.wasm.WasmCustomSection; @@ -780,13 +781,14 @@ public WasmGlobal globalAlloc(ValueType valueType, boolean mutable, Object value case i64 -> WasmGlobal.alloc64(valueType, mutable, (long) value); case f32 -> WasmGlobal.alloc32(valueType, mutable, Float.floatToRawIntBits((float) value)); case f64 -> WasmGlobal.alloc64(valueType, mutable, Double.doubleToRawLongBits((double) value)); - case v128 -> throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Invalid value type. Accessing v128 globals from JS is not allowed."); + case v128 -> throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, WasmJsApiException.V128_VALUE_ACCESS); case anyfunc, externref -> { if (!refTypes) { throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Invalid value type. Reference types are not enabled."); } yield WasmGlobal.allocRef(valueType, mutable, value); } + case exnref -> throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, WasmJsApiException.EXNREF_VALUE_ACCESS); }; } @@ -799,18 +801,20 @@ private Object globalRead(Object[] args) { } public Object globalRead(WasmGlobal global) { - int type = global.getType(); - return switch (type) { + return switch (global.getType()) { case WasmType.I32_TYPE -> global.loadAsInt(); case WasmType.I64_TYPE -> global.loadAsLong(); case WasmType.F32_TYPE -> Float.intBitsToFloat(global.loadAsInt()); case WasmType.F64_TYPE -> Double.longBitsToDouble(global.loadAsLong()); - case WasmType.V128_TYPE -> throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Invalid value type. Accessing v128 globals from JS is not allowed."); + case WasmType.V128_TYPE -> throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, WasmJsApiException.V128_VALUE_ACCESS); default -> { - assert WasmType.isReferenceType(type); + assert WasmType.isReferenceType(global.getType()); if (!refTypes) { throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Invalid value type. Reference types are not enabled."); } + if (SymbolTable.closedTypeOfPredefined(WasmType.EXNREF_TYPE).isSupertypeOf(global.getClosedValueType())) { + throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, WasmJsApiException.EXNREF_VALUE_ACCESS); + } yield global.loadAsReference(); } }; @@ -828,21 +832,23 @@ public Object globalWrite(WasmGlobal global, Object value) { if (!global.isMutable()) { throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Global is not mutable."); } - int type = global.getType(); if (!global.getClosedValueType().matchesValue(value)) { - throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Global type %s, value: %s", ValueType.fromValue(type), value); + throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Global type %s, value: %s", ValueType.fromValue(global.getType()), value); } - switch (type) { + switch (global.getType()) { case WasmType.I32_TYPE -> global.storeInt((int) value); case WasmType.I64_TYPE -> global.storeLong((long) value); case WasmType.F32_TYPE -> global.storeInt(Float.floatToRawIntBits((float) value)); case WasmType.F64_TYPE -> global.storeLong(Double.doubleToRawLongBits((double) value)); - case WasmType.V128_TYPE -> throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Invalid value type. Accessing v128 globals from JS is not allowed."); + case WasmType.V128_TYPE -> throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, WasmJsApiException.V128_VALUE_ACCESS); default -> { - assert WasmType.isReferenceType(type); + assert WasmType.isReferenceType(global.getType()); if (!refTypes) { throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Invalid value type. Reference types are not enabled."); } + if (SymbolTable.closedTypeOfPredefined(WasmType.EXNREF_TYPE).isSupertypeOf(global.getClosedValueType())) { + throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, WasmJsApiException.EXNREF_VALUE_ACCESS); + } global.storeReference(value); } } @@ -888,40 +894,12 @@ public Object exnAlloc(Object[] args) { for (int i = 0; i < paramCount; i++) { final ValueType paramType = type.paramTypeAt(i); final Object value = args[i + 1]; - switch (paramType) { - case i32: - if (!(value instanceof Integer)) { - throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Param type %s, value: %s", paramType, value); - } - break; - case i64: - if (!(value instanceof Long)) { - throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Param type %s, value: %s", paramType, value); - } - break; - case f32: - if (!(value instanceof Float)) { - throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Param type %s, value: %s", paramType, value); - } - break; - case f64: - if (!(value instanceof Double)) { - throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Param type %s, value: %s", paramType, value); - } - break; - case anyfunc: - if (!refTypes) { - throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Invalid value type. Reference types are not enabled"); - } - if (!(value == WasmConstant.NULL || value instanceof WasmFunctionInstance)) { - throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Param type %s, value: %s", paramType, value); - } - break; - case externref: - if (!refTypes) { - throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Invalid value type. Reference types are not enabled"); - } - break; + + if (!paramType.asClosedValueType().matchesValue(value)) { + throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Param type %s, value: %s", paramType, value); + } + if (ValueType.isReferenceType(paramType) && !refTypes) { + throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Invalid value type. Reference types are not enabled."); } fields[i] = value; } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/Failure.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/Failure.java index 35e52721ec91..14a28d8591cb 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/Failure.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/Failure.java @@ -143,7 +143,7 @@ public enum Failure { OUT_OF_BOUNDS_MEMORY_ACCESS(Type.TRAP, "out of bounds memory access"), UNALIGNED_ATOMIC(Type.TRAP, "unaligned atomic"), EXPECTED_SHARED_MEMORY(Type.TRAP, "expected shared memory"), - INDIRECT_CALL_TYPE__MISMATCH(Type.TRAP, "indirect call type mismatch"), + INDIRECT_CALL_TYPE_MISMATCH(Type.TRAP, "indirect call type mismatch"), INVALID_MULTI_VALUE_ARITY(Type.TRAP, "provided multi-value size does not match function type"), INVALID_TYPE_IN_MULTI_VALUE(Type.TRAP, "type of value in multi-value does not match the function type"), diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/WasmJsApiException.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/WasmJsApiException.java index 20998bbde54b..45dfa81237d8 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/WasmJsApiException.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/exception/WasmJsApiException.java @@ -51,6 +51,9 @@ */ public class WasmJsApiException extends AbstractTruffleException { + public static final String V128_VALUE_ACCESS = "Invalid value type. Accessing v128 values from JS is not allowed."; + public static final String EXNREF_VALUE_ACCESS = "Invalid value type. Accessing exnref values from JS is not allowed."; + public enum Kind { TypeError, RangeError, diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/globals/WasmGlobal.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/globals/WasmGlobal.java index 68974c2524f7..eaaba8a1369b 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/globals/WasmGlobal.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/globals/WasmGlobal.java @@ -122,6 +122,7 @@ public SymbolTable.ClosedValueType getClosedValueType() { case v128 -> SymbolTable.VectorType.V128; case anyfunc -> SymbolTable.ClosedReferenceType.FUNCREF; case externref -> SymbolTable.ClosedReferenceType.EXTERNREF; + case exnref -> SymbolTable.ClosedReferenceType.EXNREF; }; } } @@ -238,7 +239,6 @@ void writeMember(String member, Object value, case WasmType.F32_TYPE -> storeInt(Float.floatToRawIntBits(valueLibrary.asFloat(value))); case WasmType.F64_TYPE -> storeLong(Double.doubleToRawLongBits(valueLibrary.asDouble(value))); case WasmType.V128_TYPE -> { - assert WasmType.isVectorType(type); if (!getClosedValueType().matchesValue(value)) { throw UnsupportedMessageException.create(); } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java index a4a1c468714a..56a1841bd1ee 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java @@ -1935,7 +1935,7 @@ private int pushExceptionFieldsAndReference(VirtualFrame frame, WasmRuntimeExcep @TruffleBoundary private void failFunctionTypeCheck(WasmFunction function, int expectedFunctionTypeIndex) { - throw WasmException.format(Failure.INDIRECT_CALL_TYPE__MISMATCH, this, + throw WasmException.format(Failure.INDIRECT_CALL_TYPE_MISMATCH, this, "Actual (type %d of function %s) and expected (type %d in module %s) types differ in the indirect call.", function.typeIndex(), function.name(), expectedFunctionTypeIndex, module.name()); } From 8576bb23d647f4d77ccadb547c5b7bf37d649393 Mon Sep 17 00:00:00 2001 From: Jirka Marsik Date: Thu, 9 Oct 2025 22:59:30 +0200 Subject: [PATCH 26/40] Document the new GraalWasm type system --- .../src/org/graalvm/wasm/WasmType.java | 73 ++++++++++++++++++- .../wasm/parser/validation/ControlFrame.java | 6 +- 2 files changed, 76 insertions(+), 3 deletions(-) diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmType.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmType.java index 9011d7e587ee..682e148de552 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmType.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmType.java @@ -49,6 +49,44 @@ import com.oracle.truffle.api.library.ExportLibrary; import com.oracle.truffle.api.library.ExportMessage; +/** + *

+ * Wasm value types are represented as {@code int}s. Predefined types are negative values, while + * user-defined types are non-negative. The second-highest bit is used to signal if a reference type + * is nullable. For the predefined non-reference types (numbers and vectors), that bit is always + * set. + *

+ *

+ * For predefined types, the negative values are the LEB128 decodings of the bytes that represent + * these predefined types in the wasm binary format. E.g., {@code i32} is represented as + * {@code 0x7f} in the binary format, which decodes into {@code -1}, and that is the internal + * representation we use in {@link #I32_TYPE}. + *

+ *

+ * For user-defined types, the non-negative value is the type index which points to the type + * definition in the module's {@link SymbolTable}. + *

+ *

+ * Wasm heap types are represented using the same schema (a union of negative predefined types and + * user-defined non-negative type indices), but without using a special nullability bit. To build a + * reference type out of a heap type, set the nullability bit using {@link #withNullable}: + *

+ * + *
+ *     boolean nullable = ...;
+ *     int heapType = ...;
+ *     int referenceType = WasmType.withNullable(nullable, heapType);
+ * 
+ *

+ * During type checking, it is not enough to compare types by equality, as this does not take into + * account type aliases and subtyping. Instead, use {@link SymbolTable#matchesType(int, int)}. + *

+ *

+ * For an example of how to do case analysis on a Wasm value type, check the source of + * {@link #toString(int)}. NB: The types {@link #TOP} and {@link #BOT} only occur during module + * validation. + *

+ */ @ExportLibrary(InteropLibrary.class) @SuppressWarnings({"unused", "static-method"}) public class WasmType implements TruffleObject { @@ -95,10 +133,14 @@ public class WasmType implements TruffleObject { public static final int EXNREF_TYPE = EXN_HEAPTYPE; @CompilationFinal(dimensions = 1) public static final int[] EXNREF_TYPE_ARRAY = {EXNREF_TYPE}; + // Implementation-specific Types. /** - * Implementation-specific Types. + * The common supertype of all types, the universal type. */ public static final int TOP = -0x7e; + /** + * The common subtype of all types, the impossible type. + */ public static final int BOT = -0x7f; /** @@ -172,29 +214,58 @@ public static boolean isBottomType(int type) { return withNullable(true, type) == BOT; } + /** + * Indicates whether this is a user-defined reference type. + */ public static boolean isConcreteReferenceType(int type) { return type >= 0 || isBottomType(type); } + /** + * Returns the type index of this user-defined reference type. This must be used together with + * the appropriate {@link SymbolTable} to be able to expand the type's definition. + */ public static int getTypeIndex(int type) { assert isConcreteReferenceType(type); return type & TYPE_VALUE_MASK; } + /** + * Returns the "payload" of this reference type, which can be matched against the predefined + * abstract heap types (such as {@link #FUNC_HEAPTYPE} or {@link #EXTERN_HEAPTYPE}) in a switch + * statement. + */ public static int getAbstractHeapType(int type) { assert isReferenceType(type); return withNullable(true, type); } + /** + * Indicates whether this value types admits the value {@link WasmConstant#NULL}. Can only be + * called on reference types. + */ public static boolean isNullable(int type) { assert isReferenceType(type); return (type & TYPE_NULLABLE_MASK) != 0; } + /** + * Updates the nullability bit of this reference type. Can also be used to create a reference + * type from a heap type. + * + * @param nullable whether the resulting reference type should be nullable + * @param type a reference type or a heap type (one of the {@code *_HEAPTYPE} constants or a + * type index) + */ public static int withNullable(boolean nullable, int type) { return nullable ? type | TYPE_NULLABLE_MASK : type & ~TYPE_NULLABLE_MASK; } + /** + * Indicates whether this type has a default value (this is the case for all value types except + * for non-nullable reference types). Locals of such types do not have to be initialized prior + * to first access. + */ public static boolean hasDefaultValue(int type) { return !(isReferenceType(type) && !isNullable(type)); } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ControlFrame.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ControlFrame.java index b522b2003ad2..87a443c92401 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ControlFrame.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ControlFrame.java @@ -63,9 +63,11 @@ public abstract class ControlFrame { /** * @param paramTypes The parameter value types of the block structure. * @param resultTypes The result value types of the block structure. - * @param symbolTable Necessary to look up the definitions of types in {@code paramTypes} and {@code resultTypes} + * @param symbolTable Necessary to look up the definitions of types in {@code paramTypes} and + * {@code resultTypes} * @param initialStackSize The size of the value stack when entering this block structure. - * @param initializedLocals The set of locals which are already initialized at the start of this function + * @param initializedLocals The set of locals which are already initialized at the start of this + * function */ ControlFrame(int[] paramTypes, int[] resultTypes, SymbolTable symbolTable, int initialStackSize, BitSet initializedLocals) { this.paramTypes = paramTypes; From 5484f319d3e6bc2e045a897dd107385faba5c288 Mon Sep 17 00:00:00 2001 From: Jirka Marsik Date: Fri, 10 Oct 2025 10:48:24 +0200 Subject: [PATCH 27/40] Try checking type equivalence classes before subtype matching in call_indirect This reverts back to using type equivalence classes. We use them for fast type equality checks. In the optimistic case (expected type == actual type), this is enough and leads to a fast path that avoids invoking the subtype matching on every indirect call. If types are not equal, subtype matching is still performed. --- .../src/org/graalvm/wasm/Linker.java | 40 ++++++++-- .../src/org/graalvm/wasm/SymbolTable.java | 77 ++++++++++++++++++- .../src/org/graalvm/wasm/WasmFunction.java | 14 ++++ .../src/org/graalvm/wasm/WasmLanguage.java | 18 +++++ .../graalvm/wasm/nodes/WasmFunctionNode.java | 9 ++- 5 files changed, 147 insertions(+), 11 deletions(-) diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java index 8e7d3f47d97b..b5f315993e18 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java @@ -191,12 +191,7 @@ private void tryLinkOutsidePartialEvaluation(WasmInstance entryPointInstance, Im ArrayList failures = new ArrayList<>(); final int maxStartFunctionIndex = runLinkActions(store, instances, importValues, failures); linkTopologically(store, failures, maxStartFunctionIndex); - for (WasmInstance instance : instances.values()) { - WasmModule module = instance.module(); - if (instance.isLinkInProgress() && !module.isParsed()) { - module.setParsed(); - } - } + assignTypeEquivalenceClasses(store); resolutionDag = null; runStartFunctions(instances, failures); checkFailures(failures); @@ -243,6 +238,39 @@ private void linkTopologically(WasmStore store, ArrayList failures, i } } + private static void assignTypeEquivalenceClasses(WasmStore store) { + final Map instances = store.moduleInstances(); + for (WasmInstance instance : instances.values()) { + WasmModule module = instance.module(); + if (instance.isLinkInProgress() && !module.isParsed()) { + assignTypeEquivalenceClasses(module, store.language()); + } + } + } + + private static void assignTypeEquivalenceClasses(WasmModule module, WasmLanguage language) { + var lock = module.getLock(); + lock.lock(); + try { + if (module.isParsed()) { + return; + } + final SymbolTable symtab = module.symbolTable(); + for (int index = 0; index < symtab.typeCount(); index++) { + SymbolTable.ClosedFunctionType type = symtab.closedFunctionTypeAt(index); + int equivalenceClass = language.equivalenceClassFor(type); + symtab.setEquivalenceClass(index, equivalenceClass); + } + for (int index = 0; index < symtab.numFunctions(); index++) { + final WasmFunction function = symtab.function(index); + function.setTypeEquivalenceClass(symtab.equivalenceClass(function.typeIndex())); + } + module.setParsed(); + } finally { + lock.unlock(); + } + } + private static void runStartFunctions(Map instances, ArrayList failures) { List instanceList = new ArrayList<>(instances.values()); instanceList.sort(Comparator.comparingInt(RuntimeState::startFunctionIndex)); diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java index 9c8669e9b50a..f3e3e785c664 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java @@ -48,6 +48,7 @@ import java.lang.reflect.Array; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import com.oracle.truffle.api.nodes.ExplodeLoop; @@ -86,6 +87,8 @@ public abstract class SymbolTable { private static final byte GLOBAL_FUNCTION_INITIALIZER_BIT = 0x20; public static final int UNINITIALIZED_ADDRESS = Integer.MIN_VALUE; + public static final int NO_EQUIVALENCE_CLASS = 0; + public static final int FIRST_EQUIVALENCE_CLASS = NO_EQUIVALENCE_CLASS + 1; public abstract static sealed class ClosedValueType { // This is a workaround until we can use pattern matching in JDK 21+. @@ -104,6 +107,11 @@ public enum Kind { public abstract boolean matchesValue(Object value); public abstract Kind kind(); + + @Override + public boolean equals(Object other) { + return other instanceof ClosedValueType otherValueType && isSubtypeOf(otherValueType) && isSupertypeOf(otherValueType); + } } public abstract static sealed class ClosedHeapType { @@ -120,6 +128,11 @@ public enum Kind { public abstract boolean matchesValue(Object value); public abstract Kind kind(); + + @Override + public boolean equals(Object other) { + return other instanceof ClosedHeapType otherHeapType && isSubtypeOf(otherHeapType) && isSupertypeOf(otherHeapType); + } } public static final class NumberType extends ClosedValueType { @@ -163,6 +176,11 @@ public boolean matchesValue(Object val) { public Kind kind() { return Kind.Number; } + + @Override + public int hashCode() { + return value; + } } public static final class VectorType extends ClosedValueType { @@ -197,6 +215,11 @@ public boolean matchesValue(Object val) { public Kind kind() { return Kind.Vector; } + + @Override + public int hashCode() { + return value; + } } public static final class ClosedReferenceType extends ClosedValueType { @@ -244,6 +267,11 @@ public boolean matchesValue(Object value) { public Kind kind() { return Kind.Reference; } + + @Override + public int hashCode() { + return Boolean.hashCode(nullable) ^ closedHeapType.hashCode(); + } } public static final class AbstractHeapType extends ClosedHeapType { @@ -290,6 +318,11 @@ public boolean matchesValue(Object val) { public Kind kind() { return Kind.Abstract; } + + @Override + public int hashCode() { + return value; + } } public static final class ClosedFunctionType extends ClosedHeapType { @@ -375,6 +408,11 @@ public boolean matchesValue(Object value) { public Kind kind() { return Kind.Function; } + + @Override + public int hashCode() { + return Arrays.hashCode(paramTypes) ^ Arrays.hashCode(resultTypes); + } } public static final class BottomType extends ClosedValueType { @@ -402,6 +440,11 @@ public boolean matchesValue(Object value) { public Kind kind() { return Kind.Bottom; } + + @Override + public int hashCode() { + return Integer.MIN_VALUE; + } } public static final class TopType extends ClosedValueType { @@ -429,6 +472,11 @@ public boolean matchesValue(Object value) { public Kind kind() { return Kind.Top; } + + @Override + public int hashCode() { + return Integer.MAX_VALUE; + } } /** @@ -496,6 +544,15 @@ public record TagInfo(byte attribute, int typeIndex) { @CompilationFinal(dimensions = 1) private int[] typeOffsets; @CompilationFinal(dimensions = 1) private ClosedValueType[] closedTypes; + /** + * Stores the type equivalence class. + *

+ * Since multiple types have the same shape, each type is mapped to an equivalence class, so + * that two types can be quickly compared. + *

+ * The equivalence classes are computed globally for all the modules, during linking. + */ + @CompilationFinal(dimensions = 1) private int[] typeEquivalenceClasses; @CompilationFinal private int typeDataSize; @CompilationFinal private int typeCount; @@ -704,6 +761,7 @@ public record TagInfo(byte attribute, int typeIndex) { this.typeData = new int[INITIAL_DATA_SIZE]; this.typeOffsets = new int[INITIAL_TYPE_SIZE]; this.closedTypes = new ClosedValueType[INITIAL_TYPE_SIZE]; + this.typeEquivalenceClasses = new int[INITIAL_TYPE_SIZE]; this.typeDataSize = 0; this.typeCount = 0; this.importedSymbols = new ArrayList<>(); @@ -787,9 +845,9 @@ private void ensureTypeDataCapacity(int index) { } /** - * Ensure that the {@link #typeOffsets} array has enough space to store the data for the type at - * {@code index}. If there is not enough space, then a reallocation of the array takes place, - * doubling its capacity. + * Ensure that the {@link #typeOffsets} and {@link #typeEquivalenceClasses} arrays have enough + * space to store the data for the type at {@code index}. If there is not enough space, then a + * reallocation of the array takes place, doubling its capacity. *

* No synchronisation is required for this method, as it is only called during parsing, which is * carried out by a single thread. @@ -799,6 +857,7 @@ private void ensureTypeCapacity(int index) { int newLength = Math.max(Integer.highestOneBit(index) << 1, 2 * typeOffsets.length); typeOffsets = reallocate(typeOffsets, typeCount, newLength); closedTypes = reallocate(closedTypes, typeCount, newLength); + typeEquivalenceClasses = reallocate(typeEquivalenceClasses, typeCount, newLength); } } @@ -849,6 +908,18 @@ void finishFunctionType(int funcTypeIdx) { closedTypes[funcTypeIdx] = makeClosedType(funcTypeIdx); } + public int equivalenceClass(int typeIndex) { + return typeEquivalenceClasses[typeIndex]; + } + + void setEquivalenceClass(int index, int eqClass) { + checkNotParsed(); + if (typeEquivalenceClasses[index] != NO_EQUIVALENCE_CLASS) { + throw WasmException.create(Failure.UNSPECIFIED_INVALID, "Type at index " + index + " already has an equivalence class."); + } + typeEquivalenceClasses[index] = eqClass; + } + private void ensureFunctionsCapacity(int index) { if (functions.length <= index) { int newLength = Math.max(Integer.highestOneBit(index) << 1, 2 * functions.length); diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmFunction.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmFunction.java index 7f286753cc37..a54608092eea 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmFunction.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmFunction.java @@ -43,6 +43,8 @@ import com.oracle.truffle.api.CallTarget; import com.oracle.truffle.api.CompilerDirectives.CompilationFinal; import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; +import org.graalvm.wasm.exception.Failure; +import org.graalvm.wasm.exception.WasmException; public final class WasmFunction { private final SymbolTable symbolTable; @@ -50,6 +52,7 @@ public final class WasmFunction { private final ImportDescriptor importDescriptor; private final int typeIndex; private final SymbolTable.ClosedFunctionType closedFunctionType; + @CompilationFinal private int typeEquivalenceClass; @CompilationFinal private String debugName; @CompilationFinal private CallTarget callTarget; /** Interop call adapter for argument and return value validation and conversion. */ @@ -90,6 +93,13 @@ public int resultTypeAt(int returnIndex) { return symbolTable.functionTypeResultTypeAt(typeIndex, returnIndex); } + void setTypeEquivalenceClass(int typeEquivalenceClass) { + if (this.typeEquivalenceClass != SymbolTable.NO_EQUIVALENCE_CLASS) { + throw WasmException.create(Failure.UNSPECIFIED_INVALID, "Function at index " + index + " already has an equivalence class."); + } + this.typeEquivalenceClass = typeEquivalenceClass; + } + public int[] resultTypes() { return symbolTable.functionTypeResultTypesAsArray(typeIndex); } @@ -150,6 +160,10 @@ public SymbolTable.ClosedFunctionType closedType() { return closedFunctionType; } + public int typeEquivalenceClass() { + return typeEquivalenceClass; + } + public int index() { return index; } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmLanguage.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmLanguage.java index c376dab7fb3e..514c954372c3 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmLanguage.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmLanguage.java @@ -95,8 +95,26 @@ public final class WasmLanguage extends TruffleLanguage { private final Map builtinModules = new ConcurrentHashMap<>(); + private final Map equivalenceClasses = new ConcurrentHashMap<>(); + private int nextEquivalenceClass = SymbolTable.FIRST_EQUIVALENCE_CLASS; private final Map interopCallAdapters = new ConcurrentHashMap<>(); + public int equivalenceClassFor(SymbolTable.ClosedFunctionType type) { + CompilerAsserts.neverPartOfCompilation(); + Integer equivalenceClass = equivalenceClasses.get(type); + if (equivalenceClass == null) { + synchronized (this) { + equivalenceClass = equivalenceClasses.get(type); + if (equivalenceClass == null) { + equivalenceClass = nextEquivalenceClass++; + Integer prev = equivalenceClasses.put(type, equivalenceClass); + assert prev == null; + } + } + } + return equivalenceClass; + } + /** * Gets or creates the interop call adapter for a function type. Always returns the same call * target for any particular type. diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java index 56a1841bd1ee..36f6c0f2b7e8 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java @@ -582,12 +582,17 @@ public Object executeBodyFromOffset(WasmInstance instance, VirtualFrame frame, i throw WasmException.format(Failure.UNSPECIFIED_TRAP, this, "Unknown table element type: %s", element); } + int expectedTypeEquivalenceClass = symtab.equivalenceClass(expectedFunctionTypeIndex); + // Target function instance must be from the same context. assert functionInstanceContext == WasmContext.get(this); // Validate that the target function type matches the expected type of the - // indirect call. - if (!symtab.closedTypeAt(expectedFunctionTypeIndex).isSupertypeOf(new SymbolTable.ClosedReferenceType(false, function.closedType()))) { + // indirect call. We first try if the types are equivalent using the + // equivalence classes. If they are not equivalent, we run the full subtype + // matching procedure. + if (expectedTypeEquivalenceClass != function.typeEquivalenceClass() && + !symtab.closedTypeAt(expectedFunctionTypeIndex).isSupertypeOf(new SymbolTable.ClosedReferenceType(false, function.closedType()))) { enterErrorBranch(); failFunctionTypeCheck(function, expectedFunctionTypeIndex); } From 0b78d7f3f3d7cbccad8b97a33ba8ad49127281d2 Mon Sep 17 00:00:00 2001 From: Jirka Marsik Date: Fri, 17 Oct 2025 09:09:16 +0200 Subject: [PATCH 28/40] Fix LABEL_U16 parsing in the interpreter --- .../src/org/graalvm/wasm/nodes/WasmFunctionNode.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java index 36f6c0f2b7e8..78e35f703c1a 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java @@ -326,7 +326,7 @@ public Object executeBodyFromOffset(WasmInstance instance, VirtualFrame frame, i break; } case Bytecode.LABEL_U16: { - final int value = rawPeekU16(bytecode, offset); + final int value = rawPeekU8(bytecode, offset); final int stackSize = rawPeekU8(bytecode, offset + 1); offset += 2; final int resultCount = (value & BytecodeBitEncoding.LABEL_U16_RESULT_VALUE); From 64692361a7e8b8263851fd9d0af1525f568b9434 Mon Sep 17 00:00:00 2001 From: Jirka Marsik Date: Fri, 17 Oct 2025 08:47:08 +0200 Subject: [PATCH 29/40] Factor out bytecode handlers for MISC bytecodes --- .../org/graalvm/wasm/constants/Bytecode.java | 49 ++- .../graalvm/wasm/constants/StackEffects.java | 394 ++++++++++++++++++ .../constants/Vector128OpStackEffects.java | 335 --------------- .../graalvm/wasm/nodes/WasmFunctionNode.java | 317 +++++++------- .../wasm/parser/bytecode/BytecodeParser.java | 1 - 5 files changed, 585 insertions(+), 511 deletions(-) create mode 100644 wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/StackEffects.java delete mode 100644 wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/Vector128OpStackEffects.java diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/Bytecode.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/Bytecode.java index 5095fa779a27..b0c1fbb67e60 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/Bytecode.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/Bytecode.java @@ -355,35 +355,34 @@ public class Bytecode { public static final int MEMORY_INIT = 0x08; public static final int MEMORY64_INIT = 0x0A; public static final int DATA_DROP = 0x0C; - public static final int DATA_DROP_UNSAFE = 0x0D; - public static final int MEMORY_COPY = 0x0E; - public static final int MEMORY64_COPY_D32_S64 = 0x0F; - public static final int MEMORY64_COPY_D64_S32 = 0x10; - public static final int MEMORY64_COPY_D64_S64 = 0x11; - public static final int MEMORY_FILL = 0x12; - public static final int MEMORY64_FILL = 0x13; - - public static final int MEMORY64_SIZE = 0x14; - public static final int MEMORY64_GROW = 0x15; - public static final int TABLE_INIT = 0x16; - public static final int ELEM_DROP = 0x17; - public static final int TABLE_COPY = 0x18; - public static final int TABLE_GROW = 0x19; - public static final int TABLE_SIZE = 0x1A; - public static final int TABLE_FILL = 0x1B; + public static final int MEMORY_COPY = 0x0D; + public static final int MEMORY64_COPY_D32_S64 = 0x0E; + public static final int MEMORY64_COPY_D64_S32 = 0x0F; + public static final int MEMORY64_COPY_D64_S64 = 0x10; + public static final int MEMORY_FILL = 0x11; + public static final int MEMORY64_FILL = 0x12; + + public static final int MEMORY64_SIZE = 0x13; + public static final int MEMORY64_GROW = 0x14; + public static final int TABLE_INIT = 0x15; + public static final int ELEM_DROP = 0x16; + public static final int TABLE_COPY = 0x17; + public static final int TABLE_GROW = 0x18; + public static final int TABLE_SIZE = 0x19; + public static final int TABLE_FILL = 0x1A; // Exception opcodes - public static final int THROW = 0x1C; - public static final int THROW_REF = 0x1D; + public static final int THROW = 0x1B; + public static final int THROW_REF = 0x1C; // Typed function references opcodes - public static final int CALL_REF_U8 = 0x1E; - public static final int CALL_REF_I32 = 0x1F; - public static final int REF_AS_NON_NULL = 0x20; - public static final int BR_ON_NULL_U8 = 0x21; - public static final int BR_ON_NULL_I32 = 0x22; - public static final int BR_ON_NON_NULL_U8 = 0x23; - public static final int BR_ON_NON_NULL_I32 = 0x24; + public static final int CALL_REF_U8 = 0x1D; + public static final int CALL_REF_I32 = 0x1E; + public static final int REF_AS_NON_NULL = 0x1F; + public static final int BR_ON_NULL_U8 = 0x20; + public static final int BR_ON_NULL_I32 = 0x21; + public static final int BR_ON_NON_NULL_U8 = 0x22; + public static final int BR_ON_NON_NULL_I32 = 0x23; // Atomic opcodes public static final int ATOMIC_I32_LOAD = 0x00; diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/StackEffects.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/StackEffects.java new file mode 100644 index 000000000000..e14c571832c3 --- /dev/null +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/StackEffects.java @@ -0,0 +1,394 @@ +/* + * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * The Universal Permissive License (UPL), Version 1.0 + * + * Subject to the condition set forth below, permission is hereby granted to any + * person obtaining a copy of this software, associated documentation and/or + * data (collectively the "Software"), free of charge and under any and all + * copyright rights in the Software, and any and all patent rights owned or + * freely licensable by each licensor hereunder covering either (i) the + * unmodified Software as contributed to or provided by such licensor, or (ii) + * the Larger Works (as defined below), to deal in both + * + * (a) the Software, and + * + * (b) any piece of software and/or hardware listed in the lrgrwrks.txt file if + * one is included with the Software each a "Larger Work" to which the Software + * is contributed by such licensors), + * + * without restriction, including without limitation the rights to copy, create + * derivative works of, display, perform, and distribute the Software and make, + * use, sell, offer for sale, import, export, have made, and have sold the + * Software and the Larger Work(s), and to sublicense the foregoing rights on + * either these or other terms. + * + * This license is subject to the following condition: + * + * The above copyright notice and either this complete permission notice or at a + * minimum a reference to the UPL must be included in all copies or substantial + * portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +package org.graalvm.wasm.constants; + +import com.oracle.truffle.api.CompilerDirectives.CompilationFinal; +import org.graalvm.wasm.nodes.WasmFunctionNode; + +/** + * The data stored in this class tells us how much the size of the operand stack changes after + * executing a single {@code v128} instruction. This is useful in {@link WasmFunctionNode}, since it + * allows us to execute {@code v128} instructions in a separate method without needing to return the + * new value of the stack pointer from that method. + */ +public final class StackEffects { + + private static final byte NO_EFFECT = 0; + private static final byte PUSH_1 = 1; + private static final byte POP_1 = -1; + private static final byte POP_2 = -2; + private static final byte POP_3 = -3; + private static final byte UNREACHABLE = Byte.MIN_VALUE; + + @CompilationFinal(dimensions = 1) private static final byte[] miscOpStackEffects = new byte[256]; + @CompilationFinal(dimensions = 1) private static final byte[] vectorOpStackEffects = new byte[256]; + + static { + miscOpStackEffects[Bytecode.I32_TRUNC_SAT_F32_S] = NO_EFFECT; + miscOpStackEffects[Bytecode.I32_TRUNC_SAT_F32_U] = NO_EFFECT; + miscOpStackEffects[Bytecode.I32_TRUNC_SAT_F64_S] = NO_EFFECT; + miscOpStackEffects[Bytecode.I32_TRUNC_SAT_F64_U] = NO_EFFECT; + miscOpStackEffects[Bytecode.I64_TRUNC_SAT_F32_S] = NO_EFFECT; + miscOpStackEffects[Bytecode.I64_TRUNC_SAT_F32_U] = NO_EFFECT; + miscOpStackEffects[Bytecode.I64_TRUNC_SAT_F64_S] = NO_EFFECT; + miscOpStackEffects[Bytecode.I64_TRUNC_SAT_F64_U] = NO_EFFECT; + miscOpStackEffects[Bytecode.MEMORY_INIT] = POP_3; + miscOpStackEffects[Bytecode.MEMORY64_INIT] = POP_3; + miscOpStackEffects[Bytecode.DATA_DROP] = NO_EFFECT; + miscOpStackEffects[Bytecode.MEMORY_COPY] = POP_3; + miscOpStackEffects[Bytecode.MEMORY64_COPY_D32_S64] = POP_3; + miscOpStackEffects[Bytecode.MEMORY64_COPY_D64_S32] = POP_3; + miscOpStackEffects[Bytecode.MEMORY64_COPY_D64_S64] = POP_3; + miscOpStackEffects[Bytecode.MEMORY_FILL] = POP_3; + miscOpStackEffects[Bytecode.MEMORY64_FILL] = POP_3; + miscOpStackEffects[Bytecode.MEMORY64_SIZE] = PUSH_1; + miscOpStackEffects[Bytecode.MEMORY64_GROW] = NO_EFFECT; + miscOpStackEffects[Bytecode.TABLE_INIT] = POP_3; + miscOpStackEffects[Bytecode.ELEM_DROP] = NO_EFFECT; + miscOpStackEffects[Bytecode.TABLE_COPY] = POP_3; + miscOpStackEffects[Bytecode.TABLE_GROW] = POP_1; + miscOpStackEffects[Bytecode.TABLE_SIZE] = PUSH_1; + miscOpStackEffects[Bytecode.TABLE_FILL] = POP_3; + miscOpStackEffects[Bytecode.THROW] = UNREACHABLE; // unused, because stack effect is + // followed by throw + miscOpStackEffects[Bytecode.THROW_REF] = UNREACHABLE; // unused, because stack effect is + // followed by throw + miscOpStackEffects[Bytecode.CALL_REF_U8] = UNREACHABLE; // unused, because stack effect is + // variable + miscOpStackEffects[Bytecode.CALL_REF_I32] = UNREACHABLE; // unused, because stack effect is + // variable + miscOpStackEffects[Bytecode.REF_AS_NON_NULL] = NO_EFFECT; + miscOpStackEffects[Bytecode.BR_ON_NULL_U8] = UNREACHABLE; // unused, because stack effect is + // dynamic + miscOpStackEffects[Bytecode.BR_ON_NULL_I32] = UNREACHABLE; // unused, because stack effect + // is dynamic + miscOpStackEffects[Bytecode.BR_ON_NON_NULL_U8] = UNREACHABLE; // unused, because stack + // effect is dynamic + miscOpStackEffects[Bytecode.BR_ON_NON_NULL_I32] = UNREACHABLE; // unused, because stack + // effect is dynamic + + vectorOpStackEffects[Bytecode.VECTOR_V128_LOAD] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_V128_LOAD8X8_S] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_V128_LOAD8X8_U] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_V128_LOAD16X4_S] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_V128_LOAD16X4_U] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_V128_LOAD32X2_S] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_V128_LOAD32X2_U] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_V128_LOAD8_SPLAT] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_V128_LOAD16_SPLAT] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_V128_LOAD32_SPLAT] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_V128_LOAD64_SPLAT] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_V128_LOAD32_ZERO] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_V128_LOAD64_ZERO] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_V128_STORE] = POP_2; + vectorOpStackEffects[Bytecode.VECTOR_V128_LOAD8_LANE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_V128_LOAD16_LANE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_V128_LOAD32_LANE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_V128_LOAD64_LANE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_V128_STORE8_LANE] = POP_2; + vectorOpStackEffects[Bytecode.VECTOR_V128_STORE16_LANE] = POP_2; + vectorOpStackEffects[Bytecode.VECTOR_V128_STORE32_LANE] = POP_2; + vectorOpStackEffects[Bytecode.VECTOR_V128_STORE64_LANE] = POP_2; + vectorOpStackEffects[Bytecode.VECTOR_V128_CONST] = PUSH_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_SHUFFLE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_EXTRACT_LANE_S] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_EXTRACT_LANE_U] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_REPLACE_LANE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_EXTRACT_LANE_S] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_EXTRACT_LANE_U] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_REPLACE_LANE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_EXTRACT_LANE] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_REPLACE_LANE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_EXTRACT_LANE] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_REPLACE_LANE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_EXTRACT_LANE] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_REPLACE_LANE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_EXTRACT_LANE] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_REPLACE_LANE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_SWIZZLE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_SPLAT] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_SPLAT] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_SPLAT] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_SPLAT] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_SPLAT] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_SPLAT] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_EQ] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_NE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_LT_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_LT_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_GT_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_GT_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_LE_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_LE_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_GE_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_GE_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_EQ] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_NE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_LT_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_LT_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_GT_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_GT_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_LE_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_LE_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_GE_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_GE_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_EQ] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_NE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_LT_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_LT_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_GT_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_GT_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_LE_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_LE_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_GE_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_GE_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_EQ] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_NE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_LT_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_GT_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_LE_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_GE_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_EQ] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_NE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_LT] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_GT] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_LE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_GE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_EQ] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_NE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_LT] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_GT] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_LE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_GE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_V128_NOT] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_V128_AND] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_V128_ANDNOT] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_V128_OR] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_V128_XOR] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_V128_BITSELECT] = POP_2; + vectorOpStackEffects[Bytecode.VECTOR_V128_ANY_TRUE] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_ABS] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_NEG] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_POPCNT] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_ALL_TRUE] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_BITMASK] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_NARROW_I16X8_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_NARROW_I16X8_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_SHL] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_SHR_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_SHR_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_ADD] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_ADD_SAT_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_ADD_SAT_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_SUB] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_SUB_SAT_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_SUB_SAT_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_MIN_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_MIN_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_MAX_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_MAX_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_AVGR_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_EXTADD_PAIRWISE_I8X16_S] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_EXTADD_PAIRWISE_I8X16_U] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_ABS] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_NEG] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_Q15MULR_SAT_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_ALL_TRUE] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_BITMASK] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_NARROW_I32X4_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_NARROW_I32X4_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_EXTEND_LOW_I8X16_S] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_EXTEND_HIGH_I8X16_S] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_EXTEND_LOW_I8X16_U] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_EXTEND_HIGH_I8X16_U] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_SHL] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_SHR_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_SHR_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_ADD] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_ADD_SAT_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_ADD_SAT_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_SUB] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_SUB_SAT_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_SUB_SAT_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_MUL] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_MIN_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_MIN_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_MAX_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_MAX_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_AVGR_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_EXTMUL_LOW_I8X16_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_EXTMUL_HIGH_I8X16_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_EXTMUL_LOW_I8X16_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_EXTMUL_HIGH_I8X16_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_EXTADD_PAIRWISE_I16X8_S] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_EXTADD_PAIRWISE_I16X8_U] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_ABS] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_NEG] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_ALL_TRUE] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_BITMASK] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_EXTEND_LOW_I16X8_S] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_EXTEND_HIGH_I16X8_S] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_EXTEND_LOW_I16X8_U] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_EXTEND_HIGH_I16X8_U] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_SHL] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_SHR_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_SHR_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_ADD] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_SUB] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_MUL] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_MIN_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_MIN_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_MAX_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_MAX_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_DOT_I16X8_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_EXTMUL_LOW_I16X8_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_EXTMUL_HIGH_I16X8_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_EXTMUL_LOW_I16X8_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_EXTMUL_HIGH_I16X8_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_ABS] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_NEG] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_ALL_TRUE] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_BITMASK] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_EXTEND_LOW_I32X4_S] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_EXTEND_HIGH_I32X4_S] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_EXTEND_LOW_I32X4_U] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_EXTEND_HIGH_I32X4_U] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_SHL] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_SHR_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_SHR_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_ADD] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_SUB] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_MUL] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_EXTMUL_LOW_I32X4_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_EXTMUL_HIGH_I32X4_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_EXTMUL_LOW_I32X4_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_EXTMUL_HIGH_I32X4_U] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_CEIL] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_FLOOR] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_TRUNC] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_NEAREST] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_ABS] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_NEG] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_SQRT] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_ADD] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_SUB] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_MUL] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_DIV] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_MIN] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_MAX] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_PMIN] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_PMAX] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_CEIL] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_FLOOR] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_TRUNC] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_NEAREST] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_ABS] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_NEG] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_SQRT] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_ADD] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_SUB] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_MUL] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_DIV] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_MIN] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_MAX] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_PMIN] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_PMAX] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_TRUNC_SAT_F32X4_S] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_TRUNC_SAT_F32X4_U] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_CONVERT_I32X4_S] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_CONVERT_I32X4_U] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_TRUNC_SAT_F64X2_S_ZERO] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_TRUNC_SAT_F64X2_U_ZERO] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_CONVERT_LOW_I32X4_S] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_CONVERT_LOW_I32X4_U] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_DEMOTE_F64X2_ZERO] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_PROMOTE_LOW_F32X4] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_RELAXED_SWIZZLE] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_RELAXED_TRUNC_F32X4_S] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_RELAXED_TRUNC_F32X4_U] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_RELAXED_TRUNC_F64X2_S_ZERO] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_RELAXED_TRUNC_F64X2_U_ZERO] = NO_EFFECT; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_RELAXED_MADD] = POP_2; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_RELAXED_NMADD] = POP_2; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_RELAXED_MADD] = POP_2; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_RELAXED_NMADD] = POP_2; + vectorOpStackEffects[Bytecode.VECTOR_I8X16_RELAXED_LANESELECT] = POP_2; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_RELAXED_LANESELECT] = POP_2; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_RELAXED_LANESELECT] = POP_2; + vectorOpStackEffects[Bytecode.VECTOR_I64X2_RELAXED_LANESELECT] = POP_2; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_RELAXED_MIN] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F32X4_RELAXED_MAX] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_RELAXED_MIN] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_F64X2_RELAXED_MAX] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_RELAXED_Q15MULR_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I16X8_RELAXED_DOT_I8X16_I7X16_S] = POP_1; + vectorOpStackEffects[Bytecode.VECTOR_I32X4_RELAXED_DOT_I8X16_I7X16_ADD_S] = POP_2; + } + + private StackEffects() { + } + + /** + * Indicates by how much the stack grows (positive return value) or shrinks (negative return + * value) after executing the {@link Bytecode#MISC}-prefixed bytecode {@code miscOpcode}. + * + * @param miscOpcode the {@code MISC} bytecode being executed + * @return the difference between the stack size after executing the bytecode and before + * executing the bytecode + */ + public static byte getMiscOpStackEffect(int miscOpcode) { + assert miscOpcode < 256; + return miscOpStackEffects[miscOpcode]; + } + + /** + * Indicates by how much the stack grows (positive return value) or shrinks (negative return + * value) after executing the {@link Bytecode#MISC}-prefixed bytecode {@code vectorOpcode}. + * + * @param vectorOpcode the {@code VECTOR} bytecode being executed + * @return the difference between the stack size after executing the bytecode and before + * executing the bytecode + */ + public static byte getVectorOpStackEffect(int vectorOpcode) { + assert vectorOpcode < 256; + return vectorOpStackEffects[vectorOpcode]; + } +} diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/Vector128OpStackEffects.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/Vector128OpStackEffects.java deleted file mode 100644 index d15a11f796a1..000000000000 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/Vector128OpStackEffects.java +++ /dev/null @@ -1,335 +0,0 @@ -/* - * Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved. - * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. - * - * The Universal Permissive License (UPL), Version 1.0 - * - * Subject to the condition set forth below, permission is hereby granted to any - * person obtaining a copy of this software, associated documentation and/or - * data (collectively the "Software"), free of charge and under any and all - * copyright rights in the Software, and any and all patent rights owned or - * freely licensable by each licensor hereunder covering either (i) the - * unmodified Software as contributed to or provided by such licensor, or (ii) - * the Larger Works (as defined below), to deal in both - * - * (a) the Software, and - * - * (b) any piece of software and/or hardware listed in the lrgrwrks.txt file if - * one is included with the Software each a "Larger Work" to which the Software - * is contributed by such licensors), - * - * without restriction, including without limitation the rights to copy, create - * derivative works of, display, perform, and distribute the Software and make, - * use, sell, offer for sale, import, export, have made, and have sold the - * Software and the Larger Work(s), and to sublicense the foregoing rights on - * either these or other terms. - * - * This license is subject to the following condition: - * - * The above copyright notice and either this complete permission notice or at a - * minimum a reference to the UPL must be included in all copies or substantial - * portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -package org.graalvm.wasm.constants; - -import com.oracle.truffle.api.CompilerDirectives.CompilationFinal; -import org.graalvm.wasm.nodes.WasmFunctionNode; - -/** - * The data stored in this class tells us how much the size of the operand stack changes after - * executing a single {@code v128} instruction. This is useful in {@link WasmFunctionNode}, since it - * allows us to execute {@code v128} instructions in a separate method without needing to return the - * new value of the stack pointer from that method. - */ -public final class Vector128OpStackEffects { - - private static final byte NO_EFFECT = 0; - private static final byte PUSH_1 = 1; - private static final byte POP_1 = -1; - private static final byte POP_2 = -2; - - @CompilationFinal(dimensions = 1) private static final byte[] vector128OpStackEffects = new byte[256]; - - static { - vector128OpStackEffects[Bytecode.VECTOR_V128_LOAD] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_V128_LOAD8X8_S] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_V128_LOAD8X8_U] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_V128_LOAD16X4_S] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_V128_LOAD16X4_U] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_V128_LOAD32X2_S] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_V128_LOAD32X2_U] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_V128_LOAD8_SPLAT] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_V128_LOAD16_SPLAT] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_V128_LOAD32_SPLAT] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_V128_LOAD64_SPLAT] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_V128_LOAD32_ZERO] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_V128_LOAD64_ZERO] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_V128_STORE] = POP_2; - vector128OpStackEffects[Bytecode.VECTOR_V128_LOAD8_LANE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_V128_LOAD16_LANE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_V128_LOAD32_LANE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_V128_LOAD64_LANE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_V128_STORE8_LANE] = POP_2; - vector128OpStackEffects[Bytecode.VECTOR_V128_STORE16_LANE] = POP_2; - vector128OpStackEffects[Bytecode.VECTOR_V128_STORE32_LANE] = POP_2; - vector128OpStackEffects[Bytecode.VECTOR_V128_STORE64_LANE] = POP_2; - vector128OpStackEffects[Bytecode.VECTOR_V128_CONST] = PUSH_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_SHUFFLE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_EXTRACT_LANE_S] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_EXTRACT_LANE_U] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_REPLACE_LANE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_EXTRACT_LANE_S] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_EXTRACT_LANE_U] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_REPLACE_LANE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_EXTRACT_LANE] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_REPLACE_LANE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_EXTRACT_LANE] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_REPLACE_LANE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_EXTRACT_LANE] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_REPLACE_LANE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_EXTRACT_LANE] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_REPLACE_LANE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_SWIZZLE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_SPLAT] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_SPLAT] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_SPLAT] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_SPLAT] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_SPLAT] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_SPLAT] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_EQ] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_NE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_LT_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_LT_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_GT_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_GT_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_LE_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_LE_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_GE_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_GE_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_EQ] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_NE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_LT_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_LT_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_GT_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_GT_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_LE_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_LE_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_GE_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_GE_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_EQ] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_NE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_LT_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_LT_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_GT_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_GT_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_LE_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_LE_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_GE_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_GE_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_EQ] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_NE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_LT_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_GT_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_LE_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_GE_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_EQ] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_NE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_LT] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_GT] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_LE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_GE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_EQ] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_NE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_LT] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_GT] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_LE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_GE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_V128_NOT] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_V128_AND] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_V128_ANDNOT] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_V128_OR] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_V128_XOR] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_V128_BITSELECT] = POP_2; - vector128OpStackEffects[Bytecode.VECTOR_V128_ANY_TRUE] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_ABS] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_NEG] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_POPCNT] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_ALL_TRUE] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_BITMASK] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_NARROW_I16X8_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_NARROW_I16X8_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_SHL] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_SHR_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_SHR_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_ADD] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_ADD_SAT_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_ADD_SAT_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_SUB] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_SUB_SAT_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_SUB_SAT_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_MIN_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_MIN_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_MAX_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_MAX_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_AVGR_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_EXTADD_PAIRWISE_I8X16_S] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_EXTADD_PAIRWISE_I8X16_U] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_ABS] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_NEG] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_Q15MULR_SAT_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_ALL_TRUE] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_BITMASK] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_NARROW_I32X4_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_NARROW_I32X4_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_EXTEND_LOW_I8X16_S] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_EXTEND_HIGH_I8X16_S] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_EXTEND_LOW_I8X16_U] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_EXTEND_HIGH_I8X16_U] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_SHL] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_SHR_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_SHR_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_ADD] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_ADD_SAT_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_ADD_SAT_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_SUB] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_SUB_SAT_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_SUB_SAT_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_MUL] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_MIN_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_MIN_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_MAX_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_MAX_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_AVGR_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_EXTMUL_LOW_I8X16_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_EXTMUL_HIGH_I8X16_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_EXTMUL_LOW_I8X16_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_EXTMUL_HIGH_I8X16_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_EXTADD_PAIRWISE_I16X8_S] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_EXTADD_PAIRWISE_I16X8_U] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_ABS] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_NEG] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_ALL_TRUE] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_BITMASK] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_EXTEND_LOW_I16X8_S] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_EXTEND_HIGH_I16X8_S] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_EXTEND_LOW_I16X8_U] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_EXTEND_HIGH_I16X8_U] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_SHL] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_SHR_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_SHR_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_ADD] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_SUB] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_MUL] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_MIN_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_MIN_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_MAX_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_MAX_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_DOT_I16X8_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_EXTMUL_LOW_I16X8_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_EXTMUL_HIGH_I16X8_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_EXTMUL_LOW_I16X8_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_EXTMUL_HIGH_I16X8_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_ABS] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_NEG] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_ALL_TRUE] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_BITMASK] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_EXTEND_LOW_I32X4_S] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_EXTEND_HIGH_I32X4_S] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_EXTEND_LOW_I32X4_U] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_EXTEND_HIGH_I32X4_U] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_SHL] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_SHR_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_SHR_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_ADD] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_SUB] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_MUL] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_EXTMUL_LOW_I32X4_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_EXTMUL_HIGH_I32X4_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_EXTMUL_LOW_I32X4_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_EXTMUL_HIGH_I32X4_U] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_CEIL] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_FLOOR] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_TRUNC] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_NEAREST] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_ABS] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_NEG] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_SQRT] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_ADD] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_SUB] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_MUL] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_DIV] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_MIN] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_MAX] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_PMIN] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_PMAX] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_CEIL] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_FLOOR] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_TRUNC] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_NEAREST] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_ABS] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_NEG] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_SQRT] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_ADD] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_SUB] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_MUL] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_DIV] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_MIN] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_MAX] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_PMIN] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_PMAX] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_TRUNC_SAT_F32X4_S] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_TRUNC_SAT_F32X4_U] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_CONVERT_I32X4_S] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_CONVERT_I32X4_U] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_TRUNC_SAT_F64X2_S_ZERO] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_TRUNC_SAT_F64X2_U_ZERO] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_CONVERT_LOW_I32X4_S] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_CONVERT_LOW_I32X4_U] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_DEMOTE_F64X2_ZERO] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_PROMOTE_LOW_F32X4] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_RELAXED_SWIZZLE] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_RELAXED_TRUNC_F32X4_S] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_RELAXED_TRUNC_F32X4_U] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_RELAXED_TRUNC_F64X2_S_ZERO] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_RELAXED_TRUNC_F64X2_U_ZERO] = NO_EFFECT; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_RELAXED_MADD] = POP_2; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_RELAXED_NMADD] = POP_2; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_RELAXED_MADD] = POP_2; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_RELAXED_NMADD] = POP_2; - vector128OpStackEffects[Bytecode.VECTOR_I8X16_RELAXED_LANESELECT] = POP_2; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_RELAXED_LANESELECT] = POP_2; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_RELAXED_LANESELECT] = POP_2; - vector128OpStackEffects[Bytecode.VECTOR_I64X2_RELAXED_LANESELECT] = POP_2; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_RELAXED_MIN] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F32X4_RELAXED_MAX] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_RELAXED_MIN] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_F64X2_RELAXED_MAX] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_RELAXED_Q15MULR_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I16X8_RELAXED_DOT_I8X16_I7X16_S] = POP_1; - vector128OpStackEffects[Bytecode.VECTOR_I32X4_RELAXED_DOT_I8X16_I7X16_ADD_S] = POP_2; - } - - private Vector128OpStackEffects() { - } - - /** - * Indicates by how much the stack grows (positive return value) or shrinks (negative return - * value) after executing the {@code v128} instruction with opcode {@code vectorOpcode}. - * - * @param vectorOpcode the {@code v128} instruction being executed - * @return the difference between the stack size after executing the instruction and before - * executing the instruction - */ - public static byte getVector128OpStackEffect(int vectorOpcode) { - assert vectorOpcode < 256; - return vector128OpStackEffects[vectorOpcode]; - } -} diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java index 78e35f703c1a..4f7b3a9d537e 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java @@ -84,7 +84,7 @@ import org.graalvm.wasm.constants.Bytecode; import org.graalvm.wasm.constants.BytecodeBitEncoding; import org.graalvm.wasm.constants.ExceptionHandlerType; -import org.graalvm.wasm.constants.Vector128OpStackEffects; +import org.graalvm.wasm.constants.StackEffects; import org.graalvm.wasm.exception.Failure; import org.graalvm.wasm.exception.WasmException; import org.graalvm.wasm.exception.WasmRuntimeException; @@ -1527,142 +1527,6 @@ public Object executeBodyFromOffset(WasmInstance instance, VirtualFrame frame, i offset++; CompilerAsserts.partialEvaluationConstant(miscOpcode); switch (miscOpcode) { - case Bytecode.I32_TRUNC_SAT_F32_S: - i32_trunc_sat_f32_s(frame, stackPointer); - break; - case Bytecode.I32_TRUNC_SAT_F32_U: - i32_trunc_sat_f32_u(frame, stackPointer); - break; - case Bytecode.I32_TRUNC_SAT_F64_S: - i32_trunc_sat_f64_s(frame, stackPointer); - break; - case Bytecode.I32_TRUNC_SAT_F64_U: - i32_trunc_sat_f64_u(frame, stackPointer); - break; - case Bytecode.I64_TRUNC_SAT_F32_S: - i64_trunc_sat_f32_s(frame, stackPointer); - break; - case Bytecode.I64_TRUNC_SAT_F32_U: - i64_trunc_sat_f32_u(frame, stackPointer); - break; - case Bytecode.I64_TRUNC_SAT_F64_S: - i64_trunc_sat_f64_s(frame, stackPointer); - break; - case Bytecode.I64_TRUNC_SAT_F64_U: - i64_trunc_sat_f64_u(frame, stackPointer); - break; - case Bytecode.MEMORY_INIT: - case Bytecode.MEMORY64_INIT: { - final int dataIndex = rawPeekI32(bytecode, offset); - final int memoryIndex = rawPeekI32(bytecode, offset + 4); - executeMemoryInit(instance, frame, stackPointer, miscOpcode, memoryIndex, dataIndex); - stackPointer -= 3; - offset += 8; - break; - } - case Bytecode.DATA_DROP: { - final int dataIndex = rawPeekI32(bytecode, offset); - data_drop(instance, dataIndex); - offset += 4; - break; - } - case Bytecode.MEMORY_COPY: - case Bytecode.MEMORY64_COPY_D64_S64: - case Bytecode.MEMORY64_COPY_D64_S32: - case Bytecode.MEMORY64_COPY_D32_S64: { - final int destMemoryIndex = rawPeekI32(bytecode, offset); - final int srcMemoryIndex = rawPeekI32(bytecode, offset + 4); - executeMemoryCopy(instance, frame, stackPointer, miscOpcode, destMemoryIndex, srcMemoryIndex); - stackPointer -= 3; - offset += 8; - break; - } - case Bytecode.MEMORY_FILL: - case Bytecode.MEMORY64_FILL: { - final int memoryIndex = rawPeekI32(bytecode, offset); - executeMemoryFill(instance, frame, stackPointer, miscOpcode, memoryIndex); - stackPointer -= 3; - offset += 4; - break; - } - case Bytecode.TABLE_INIT: { - final int elementIndex = rawPeekI32(bytecode, offset); - final int tableIndex = rawPeekI32(bytecode, offset + 4); - - final int n = popInt(frame, stackPointer - 1); - final int src = popInt(frame, stackPointer - 2); - final int dst = popInt(frame, stackPointer - 3); - table_init(instance, n, src, dst, tableIndex, elementIndex); - stackPointer -= 3; - offset += 8; - break; - } - case Bytecode.ELEM_DROP: { - final int elementIndex = rawPeekI32(bytecode, offset); - instance.dropElemInstance(elementIndex); - offset += 4; - break; - } - case Bytecode.TABLE_COPY: { - final int srcIndex = rawPeekI32(bytecode, offset); - final int dstIndex = rawPeekI32(bytecode, offset + 4); - - final int n = popInt(frame, stackPointer - 1); - final int src = popInt(frame, stackPointer - 2); - final int dst = popInt(frame, stackPointer - 3); - table_copy(instance, n, src, dst, srcIndex, dstIndex); - stackPointer -= 3; - offset += 8; - break; - } - case Bytecode.TABLE_GROW: { - final int tableIndex = rawPeekI32(bytecode, offset); - - final int n = popInt(frame, stackPointer - 1); - final Object val = popReference(frame, stackPointer - 2); - - final int res = table_grow(instance, n, val, tableIndex); - pushInt(frame, stackPointer - 2, res); - stackPointer--; - offset += 4; - break; - } - case Bytecode.TABLE_SIZE: { - final int tableIndex = rawPeekI32(bytecode, offset); - table_size(instance, frame, stackPointer, tableIndex); - stackPointer++; - offset += 4; - break; - } - case Bytecode.TABLE_FILL: { - final int tableIndex = rawPeekI32(bytecode, offset); - - final int n = popInt(frame, stackPointer - 1); - final Object val = popReference(frame, stackPointer - 2); - final int i = popInt(frame, stackPointer - 3); - table_fill(instance, n, val, i, tableIndex); - stackPointer -= 3; - offset += 4; - break; - } - case Bytecode.MEMORY64_SIZE: { - final int memoryIndex = rawPeekI32(bytecode, offset); - offset += 4; - final WasmMemory memory = memory(instance, memoryIndex); - long pageSize = memoryLib(memoryIndex).size(memory); - pushLong(frame, stackPointer, pageSize); - stackPointer++; - break; - } - case Bytecode.MEMORY64_GROW: { - final int memoryIndex = rawPeekI32(bytecode, offset); - offset += 4; - final WasmMemory memory = memory(instance, memoryIndex); - long extraSize = popLong(frame, stackPointer - 1); - long previousSize = memoryLib(memoryIndex).grow(memory, extraSize); - pushLong(frame, stackPointer - 1, previousSize); - break; - } case Bytecode.THROW: { codeEntry.exceptionBranch(); final int tagIndex = rawPeekI32(bytecode, offset); @@ -1729,15 +1593,6 @@ public Object executeBodyFromOffset(WasmInstance instance, VirtualFrame frame, i CompilerAsserts.partialEvaluationConstant(stackPointer); break; } - case Bytecode.REF_AS_NON_NULL: { - Object reference = popReference(frame, stackPointer - 1); - if (reference == WasmConstant.NULL) { - enterErrorBranch(); - throw WasmException.format(Failure.NULL_REFERENCE, this, "Function reference is null"); - } - pushReference(frame, stackPointer - 1, reference); - break; - } case Bytecode.BR_ON_NULL_U8: { Object reference = popReference(frame, --stackPointer); if (profileCondition(bytecode, offset + 1, reference == WasmConstant.NULL)) { @@ -1788,8 +1643,11 @@ public Object executeBodyFromOffset(WasmInstance instance, VirtualFrame frame, i } break; } - default: - throw CompilerDirectives.shouldNotReachHere(); + default: { + offset = executeMisc(instance, frame, offset, stackPointer, miscOpcode); + stackPointer += StackEffects.getMiscOpStackEffect(miscOpcode); + break; + } } break; } @@ -1825,7 +1683,7 @@ public Object executeBodyFromOffset(WasmInstance instance, VirtualFrame frame, i offset++; CompilerAsserts.partialEvaluationConstant(vectorOpcode); offset = executeVector(instance, frame, offset, stackPointer, vectorOpcode); - stackPointer += Vector128OpStackEffects.getVector128OpStackEffect(vectorOpcode); + stackPointer += StackEffects.getVectorOpStackEffect(vectorOpcode); break; } case Bytecode.NOTIFY: { @@ -2271,6 +2129,165 @@ private void executeMemoryFill(WasmInstance instance, VirtualFrame frame, int st memory_fill(instance, n, val, dst, memoryIndex); } + private int executeMisc(WasmInstance instance, VirtualFrame frame, int startingOffset, int startingStackPointer, int miscOpcode) { + int offset = startingOffset; + int stackPointer = startingStackPointer; + + CompilerAsserts.partialEvaluationConstant(miscOpcode); + switch (miscOpcode) { + case Bytecode.I32_TRUNC_SAT_F32_S: + i32_trunc_sat_f32_s(frame, stackPointer); + break; + case Bytecode.I32_TRUNC_SAT_F32_U: + i32_trunc_sat_f32_u(frame, stackPointer); + break; + case Bytecode.I32_TRUNC_SAT_F64_S: + i32_trunc_sat_f64_s(frame, stackPointer); + break; + case Bytecode.I32_TRUNC_SAT_F64_U: + i32_trunc_sat_f64_u(frame, stackPointer); + break; + case Bytecode.I64_TRUNC_SAT_F32_S: + i64_trunc_sat_f32_s(frame, stackPointer); + break; + case Bytecode.I64_TRUNC_SAT_F32_U: + i64_trunc_sat_f32_u(frame, stackPointer); + break; + case Bytecode.I64_TRUNC_SAT_F64_S: + i64_trunc_sat_f64_s(frame, stackPointer); + break; + case Bytecode.I64_TRUNC_SAT_F64_U: + i64_trunc_sat_f64_u(frame, stackPointer); + break; + case Bytecode.MEMORY_INIT: + case Bytecode.MEMORY64_INIT: { + final int dataIndex = rawPeekI32(bytecode, offset); + final int memoryIndex = rawPeekI32(bytecode, offset + 4); + executeMemoryInit(instance, frame, stackPointer, miscOpcode, memoryIndex, dataIndex); + stackPointer -= 3; + offset += 8; + break; + } + case Bytecode.DATA_DROP: { + final int dataIndex = rawPeekI32(bytecode, offset); + data_drop(instance, dataIndex); + offset += 4; + break; + } + case Bytecode.MEMORY_COPY: + case Bytecode.MEMORY64_COPY_D64_S64: + case Bytecode.MEMORY64_COPY_D64_S32: + case Bytecode.MEMORY64_COPY_D32_S64: { + final int destMemoryIndex = rawPeekI32(bytecode, offset); + final int srcMemoryIndex = rawPeekI32(bytecode, offset + 4); + executeMemoryCopy(instance, frame, stackPointer, miscOpcode, destMemoryIndex, srcMemoryIndex); + stackPointer -= 3; + offset += 8; + break; + } + case Bytecode.MEMORY_FILL: + case Bytecode.MEMORY64_FILL: { + final int memoryIndex = rawPeekI32(bytecode, offset); + executeMemoryFill(instance, frame, stackPointer, miscOpcode, memoryIndex); + stackPointer -= 3; + offset += 4; + break; + } + case Bytecode.TABLE_INIT: { + final int elementIndex = rawPeekI32(bytecode, offset); + final int tableIndex = rawPeekI32(bytecode, offset + 4); + + final int n = popInt(frame, stackPointer - 1); + final int src = popInt(frame, stackPointer - 2); + final int dst = popInt(frame, stackPointer - 3); + table_init(instance, n, src, dst, tableIndex, elementIndex); + stackPointer -= 3; + offset += 8; + break; + } + case Bytecode.ELEM_DROP: { + final int elementIndex = rawPeekI32(bytecode, offset); + instance.dropElemInstance(elementIndex); + offset += 4; + break; + } + case Bytecode.TABLE_COPY: { + final int srcIndex = rawPeekI32(bytecode, offset); + final int dstIndex = rawPeekI32(bytecode, offset + 4); + + final int n = popInt(frame, stackPointer - 1); + final int src = popInt(frame, stackPointer - 2); + final int dst = popInt(frame, stackPointer - 3); + table_copy(instance, n, src, dst, srcIndex, dstIndex); + stackPointer -= 3; + offset += 8; + break; + } + case Bytecode.TABLE_GROW: { + final int tableIndex = rawPeekI32(bytecode, offset); + + final int n = popInt(frame, stackPointer - 1); + final Object val = popReference(frame, stackPointer - 2); + + final int res = table_grow(instance, n, val, tableIndex); + pushInt(frame, stackPointer - 2, res); + stackPointer--; + offset += 4; + break; + } + case Bytecode.TABLE_SIZE: { + final int tableIndex = rawPeekI32(bytecode, offset); + table_size(instance, frame, stackPointer, tableIndex); + stackPointer++; + offset += 4; + break; + } + case Bytecode.TABLE_FILL: { + final int tableIndex = rawPeekI32(bytecode, offset); + + final int n = popInt(frame, stackPointer - 1); + final Object val = popReference(frame, stackPointer - 2); + final int i = popInt(frame, stackPointer - 3); + table_fill(instance, n, val, i, tableIndex); + stackPointer -= 3; + offset += 4; + break; + } + case Bytecode.MEMORY64_SIZE: { + final int memoryIndex = rawPeekI32(bytecode, offset); + offset += 4; + final WasmMemory memory = memory(instance, memoryIndex); + long pageSize = memoryLib(memoryIndex).size(memory); + pushLong(frame, stackPointer, pageSize); + stackPointer++; + break; + } + case Bytecode.MEMORY64_GROW: { + final int memoryIndex = rawPeekI32(bytecode, offset); + offset += 4; + final WasmMemory memory = memory(instance, memoryIndex); + long extraSize = popLong(frame, stackPointer - 1); + long previousSize = memoryLib(memoryIndex).grow(memory, extraSize); + pushLong(frame, stackPointer - 1, previousSize); + break; + } + case Bytecode.REF_AS_NON_NULL: { + Object reference = popReference(frame, stackPointer - 1); + if (reference == WasmConstant.NULL) { + enterErrorBranch(); + throw WasmException.format(Failure.NULL_REFERENCE, this, "Function reference is null"); + } + pushReference(frame, stackPointer - 1, reference); + break; + } + default: + throw CompilerDirectives.shouldNotReachHere(); + } + + assert stackPointer - startingStackPointer == StackEffects.getMiscOpStackEffect(miscOpcode); + return offset; + } + private int executeAtomic(VirtualFrame frame, int stackPointer, int opcode, WasmMemory memory, WasmMemoryLibrary memoryLib, long memOffset, int indexType64) { switch (opcode) { case Bytecode.ATOMIC_NOTIFY: @@ -3329,7 +3346,7 @@ private int executeVector(WasmInstance instance, VirtualFrame frame, int startin throw CompilerDirectives.shouldNotReachHere(); } - assert stackPointer - startingStackPointer == Vector128OpStackEffects.getVector128OpStackEffect(vectorOpcode); + assert stackPointer - startingStackPointer == StackEffects.getVectorOpStackEffect(vectorOpcode); return offset; } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/BytecodeParser.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/BytecodeParser.java index cf96af2e9e08..978a5250af39 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/BytecodeParser.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/BytecodeParser.java @@ -828,7 +828,6 @@ private static List readCallNodes(byte[] bytecode, int startOffset, in case Bytecode.MEMORY64_SIZE: case Bytecode.MEMORY64_GROW: case Bytecode.DATA_DROP: - case Bytecode.DATA_DROP_UNSAFE: case Bytecode.ELEM_DROP: case Bytecode.TABLE_GROW: case Bytecode.TABLE_SIZE: From a26351dd77849d35ab9913270c435eea2e4ddcc7 Mon Sep 17 00:00:00 2001 From: Jirka Marsik Date: Sun, 19 Oct 2025 23:19:46 +0200 Subject: [PATCH 30/40] Cleanup and restored fixes after reverting back to closed types --- .../test/suites/bytecode/BytecodeSuite.java | 2 +- .../ReferenceTypesValidationSuite.java | 1 - .../suites/validation/ValidationSuite.java | 6 +- .../src/org/graalvm/wasm/BinaryParser.java | 23 +- .../src/org/graalvm/wasm/Linker.java | 8 +- .../src/org/graalvm/wasm/SymbolTable.java | 262 ++++++++---------- .../src/org/graalvm/wasm/WasmTable.java | 9 +- .../src/org/graalvm/wasm/WasmType.java | 16 +- .../wasm/api/ExecuteHostFunctionNode.java | 6 +- .../wasm/api/InteropCallAdapterNode.java | 1 - .../src/org/graalvm/wasm/api/ValueType.java | 2 - .../src/org/graalvm/wasm/api/WebAssembly.java | 6 +- .../org/graalvm/wasm/globals/WasmGlobal.java | 16 +- .../graalvm/wasm/nodes/WasmFunctionNode.java | 2 +- .../wasm/parser/validation/IfFrame.java | 2 +- .../wasm/parser/validation/ParserState.java | 6 +- .../parser/validation/ValidationErrors.java | 5 + 17 files changed, 168 insertions(+), 205 deletions(-) diff --git a/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/bytecode/BytecodeSuite.java b/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/bytecode/BytecodeSuite.java index fab3b7403549..2eca736fdadd 100644 --- a/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/bytecode/BytecodeSuite.java +++ b/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/bytecode/BytecodeSuite.java @@ -689,7 +689,7 @@ public void testElemHeaderExternref() { @Test public void testElemHeaderExnref() { - test(b -> b.addElemHeader(SegmentMode.ACTIVE, 8, WasmType.EXNREF_TYPE, 0, null, -1), new byte[]{0x40, 0x30, 0x08}); + test(b -> b.addElemHeader(SegmentMode.ACTIVE, 8, WasmType.EXNREF_TYPE, 0, null, -1), new byte[]{0x40, 0x10, WasmType.EXNREF_TYPE, 0x08}); } @Test diff --git a/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/validation/ReferenceTypesValidationSuite.java b/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/validation/ReferenceTypesValidationSuite.java index 8307423ee8fc..31dcf9980adc 100644 --- a/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/validation/ReferenceTypesValidationSuite.java +++ b/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/validation/ReferenceTypesValidationSuite.java @@ -537,7 +537,6 @@ public void testMultipleTables() throws IOException { // (table 1 1 exnref) final byte[] binary = newBuilder().addTable(1, 1, WasmType.FUNCREF_TYPE).addTable(1, 1, WasmType.EXTERNREF_TYPE).addTable(1, 1, WasmType.EXNREF_TYPE).build(); runParserTest(binary, options -> options.option("wasm.Exceptions", "true"), Context::eval); - runParserTest(binary, Context::eval); } @Test diff --git a/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/validation/ValidationSuite.java b/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/validation/ValidationSuite.java index 03c6ba46e99d..33de86503d32 100644 --- a/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/validation/ValidationSuite.java +++ b/wasm/src/org.graalvm.wasm.test/src/org/graalvm/wasm/test/suites/validation/ValidationSuite.java @@ -220,13 +220,13 @@ public static Collection data() { // The type `C.types[x]` must be defined in the context. stringCase( "Function - invalid type index", - "unknown type: 1 should be < 1", + "Function type variable 1 out of range. (max 0)", "(type (func (result i32))) (func (export \"f\") (type 1))", Failure.Type.INVALID), stringCase( "Function - invalid type index", - "unknown type: 4294967254 should be < 1", - "(type (func (result i32))) (func (export \"f\") (type 4294967254))", + "Function type variable 1073741823 out of range. (max 0)", + "(type (func (result i32))) (func (export \"f\") (type 1073741823))", Failure.Type.INVALID), // Under the context `C'`, the expression `express` must be valid with type diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java index 50246cbad1f8..08834a1b1303 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java @@ -49,6 +49,7 @@ import static org.graalvm.wasm.Assert.assertUnsignedIntLessOrEqual; import static org.graalvm.wasm.Assert.assertUnsignedLongLessOrEqual; import static org.graalvm.wasm.Assert.fail; +import static org.graalvm.wasm.WasmType.BOT; import static org.graalvm.wasm.WasmType.EXNREF_TYPE; import static org.graalvm.wasm.WasmType.EXN_HEAPTYPE; import static org.graalvm.wasm.WasmType.EXTERNREF_TYPE; @@ -817,7 +818,7 @@ private CodeEntry readFunction(int functionIndex, int[] locals, int sourceCodeEn final int tableIndex = readTableIndex(); // Pop the function index to call state.popChecked(I32_TYPE); - Assert.assertTrue(module.matches(FUNCREF_TYPE, module.tableElementType(tableIndex)), Failure.TYPE_MISMATCH); + Assert.assertTrue(module.matchesType(FUNCREF_TYPE, module.tableElementType(tableIndex)), Failure.TYPE_MISMATCH); // Pop parameters for (int i = module.functionTypeParamCount(expectedFunctionTypeIndex) - 1; i >= 0; --i) { @@ -850,8 +851,8 @@ private CodeEntry readFunction(int functionIndex, int[] locals, int sourceCodeEn final int t1 = state.pop(); // first operand final int t2 = state.pop(); // second operand assertTrue((WasmType.isNumberType(t1) || WasmType.isVectorType(t1)) && (WasmType.isNumberType(t2) || WasmType.isVectorType(t2)), Failure.TYPE_MISMATCH); - assertTrue(t1 == t2 || WasmType.isBottomType(t1) || WasmType.isBottomType(t2), Failure.TYPE_MISMATCH); - final int t = WasmType.isBottomType(t1) ? t2 : t1; + assertTrue(t1 == t2 || t1 == BOT || t2 == BOT, Failure.TYPE_MISMATCH); + final int t = t1 == BOT ? t2 : t1; state.push(t); if (WasmType.isNumberType(t)) { state.addSelectInstruction(Bytecode.SELECT); @@ -1655,7 +1656,7 @@ private void readNumericInstructions(ParserState state, int opcode) { final int destinationElementType = module.tableElementType(destinationTableIndex); final int sourceTableIndex = readTableIndex(); final int sourceElementType = module.tableElementType(sourceTableIndex); - Assert.assertTrue(module.matches(destinationElementType, sourceElementType), Failure.TYPE_MISMATCH); + Assert.assertTrue(module.matchesType(destinationElementType, sourceElementType), Failure.TYPE_MISMATCH); state.popChecked(I32_TYPE); state.popChecked(I32_TYPE); state.popChecked(I32_TYPE); @@ -2838,7 +2839,7 @@ private long[] readFunctionIndices(int elemType) { final int functionIndex = readDeclaredFunctionIndex(); module.addFunctionReference(functionIndex); final int functionReferenceType = WasmType.withNullable(false, module.function(functionIndex).typeIndex()); - Assert.assertTrue(module.matches(elemType, functionReferenceType), Failure.TYPE_MISMATCH); + Assert.assertTrue(module.matchesType(elemType, functionReferenceType), Failure.TYPE_MISMATCH); functionIndices[index] = ((long) ELEM_ITEM_REF_FUNC_ENTRY_PREFIX << 32) | functionIndex; } return functionIndices; @@ -2877,21 +2878,21 @@ private long[] readElemExpressions(int elemType) { case Instructions.REF_NULL: final int heapType = readHeapType(); final int nullableReferenceType = WasmType.withNullable(true, heapType); - Assert.assertTrue(module.matches(elemType, nullableReferenceType), "Invalid ref.null type: 0x%02X", Failure.TYPE_MISMATCH); + Assert.assertTrue(module.matchesType(elemType, nullableReferenceType), "Invalid ref.null type: 0x%02X", Failure.TYPE_MISMATCH); elements[index] = ((long) ELEM_ITEM_REF_NULL_ENTRY_PREFIX << 32); break; case Instructions.REF_FUNC: final int functionIndex = readDeclaredFunctionIndex(); module.addFunctionReference(functionIndex); final int functionReferenceType = WasmType.withNullable(false, module.function(functionIndex).typeIndex()); - Assert.assertTrue(module.matches(elemType, functionReferenceType), "Invalid element type: 0x%02X", Failure.TYPE_MISMATCH); + Assert.assertTrue(module.matchesType(elemType, functionReferenceType), "Invalid element type: 0x%02X", Failure.TYPE_MISMATCH); elements[index] = ((long) ELEM_ITEM_REF_FUNC_ENTRY_PREFIX << 32) | functionIndex; break; case Instructions.GLOBAL_GET: final int globalIndex = readGlobalIndex(); assertIntEqual(module.globalMutability(globalIndex), GlobalModifier.CONSTANT, Failure.CONSTANT_EXPRESSION_REQUIRED); final int valueType = module.globalValueType(globalIndex); - Assert.assertTrue(module.matches(elemType, valueType), Failure.TYPE_MISMATCH); + Assert.assertTrue(module.matchesType(elemType, valueType), Failure.TYPE_MISMATCH); elements[index] = ((long) ELEM_ITEM_GLOBAL_GET_ENTRY_PREFIX << 32) | globalIndex; break; case Instructions.VECTOR: @@ -3335,7 +3336,11 @@ private int readDeclaredFunctionIndex() { */ public void checkFunctionTypeExists(int typeIndex) { if (compareUnsigned(typeIndex, module.typeCount()) >= 0) { - throw ValidationErrors.createMissingFunctionType(typeIndex, module.tableCount() - 1); + if (module.typeCount() > 0) { + throw ValidationErrors.createMissingFunctionType(typeIndex, module.typeCount() - 1); + } else { + throw ValidationErrors.createMissingFunctionType(typeIndex); + } } } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java index b5f315993e18..8b4064102567 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java @@ -362,15 +362,15 @@ void resolveGlobalImport(WasmStore store, WasmInstance instance, ImportDescripto externalGlobal = importedInstance.externalGlobal(exportedGlobalIndex); } - if (instance.symbolTable().makeClosedType(valueType).isSupertypeOf(exportedClosedValueType)) { + if (!instance.symbolTable().closedTypeOf(valueType).isSupertypeOf(exportedClosedValueType)) { throw WasmException.create(Failure.INCOMPATIBLE_IMPORT_TYPE, "Global variable '" + importedGlobalName + "' is imported into module '" + instance.name() + "' with the type " + WasmType.toString(valueType) + ", " + - "'but it was exported in the module '" + importedModuleName + "' with the type " + WasmType.toString(exportedValueType) + "."); + "but it was exported in the module '" + importedModuleName + "' with the type " + WasmType.toString(exportedValueType) + "."); } if (exportedMutability != mutability) { throw WasmException.create(Failure.INCOMPATIBLE_IMPORT_TYPE, "Global variable '" + importedGlobalName + "' is imported into module '" + instance.name() + "' with the modifier " + GlobalModifier.asString(mutability) + ", " + - "'but it was exported in the module '" + importedModuleName + "' with the modifier " + GlobalModifier.asString(exportedMutability) + "."); + "but it was exported in the module '" + importedModuleName + "' with the modifier " + GlobalModifier.asString(exportedMutability) + "."); } instance.setExternalGlobal(globalIndex, externalGlobal); instance.globals().setInitialized(globalIndex, true); @@ -881,7 +881,7 @@ void resolveTableImport(WasmStore store, WasmInstance instance, ImportDescriptor // MAX_TABLE_DECLARATION_SIZE, so this condition will pass. assertUnsignedIntLessOrEqual(declaredMinSize, importedTable.minSize(), Failure.INCOMPATIBLE_IMPORT_TYPE); assertUnsignedIntGreaterOrEqual(declaredMaxSize, importedTable.declaredMaxSize(), Failure.INCOMPATIBLE_IMPORT_TYPE); - assertTrue(instance.symbolTable().makeClosedType(elemType).isSupertypeOf(importedTable.closedValueType()), Failure.INCOMPATIBLE_IMPORT_TYPE); + assertTrue(instance.symbolTable().closedTypeOf(elemType).isSupertypeOf(importedTable.closedElemType()), Failure.INCOMPATIBLE_IMPORT_TYPE); instance.setTableAddress(tableIndex, tableAddress); }; final ImportTableSym importTableSym = new ImportTableSym(instance.name(), importDescriptor); diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java index f3e3e785c664..dbc060beeb28 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java @@ -46,7 +46,6 @@ import static org.graalvm.wasm.WasmMath.maxUnsigned; import static org.graalvm.wasm.WasmMath.minUnsigned; -import java.lang.reflect.Array; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -95,9 +94,7 @@ public abstract static sealed class ClosedValueType { public enum Kind { Number, Vector, - Reference, - Bottom, - Top + Reference } public abstract boolean isSupertypeOf(ClosedValueType valueSubType); @@ -107,11 +104,6 @@ public enum Kind { public abstract boolean matchesValue(Object value); public abstract Kind kind(); - - @Override - public boolean equals(Object other) { - return other instanceof ClosedValueType otherValueType && isSubtypeOf(otherValueType) && isSupertypeOf(otherValueType); - } } public abstract static sealed class ClosedHeapType { @@ -128,11 +120,6 @@ public enum Kind { public abstract boolean matchesValue(Object value); public abstract Kind kind(); - - @Override - public boolean equals(Object other) { - return other instanceof ClosedHeapType otherHeapType && isSubtypeOf(otherHeapType) && isSupertypeOf(otherHeapType); - } } public static final class NumberType extends ClosedValueType { @@ -153,12 +140,12 @@ public int value() { @Override public boolean isSupertypeOf(ClosedValueType valueSubType) { - return valueSubType == BottomType.BOTTOM || valueSubType == this; + return valueSubType == this; } @Override public boolean isSubtypeOf(ClosedValueType valueSuperType) { - return valueSuperType == TopType.TOP || valueSuperType == this; + return valueSuperType == this; } @Override @@ -177,6 +164,11 @@ public Kind kind() { return Kind.Number; } + @Override + public boolean equals(Object that) { + return this == that; + } + @Override public int hashCode() { return value; @@ -198,12 +190,12 @@ public int value() { @Override public boolean isSupertypeOf(ClosedValueType valueSubType) { - return valueSubType == BottomType.BOTTOM || valueSubType == V128; + return valueSubType == V128; } @Override public boolean isSubtypeOf(ClosedValueType valueSuperType) { - return valueSuperType == TopType.TOP || valueSuperType == V128; + return valueSuperType == V128; } @Override @@ -216,6 +208,11 @@ public Kind kind() { return Kind.Vector; } + @Override + public boolean equals(Object that) { + return this == that; + } + @Override public int hashCode() { return value; @@ -248,14 +245,14 @@ public ClosedHeapType heapType() { @Override public boolean isSupertypeOf(ClosedValueType valueSubType) { - return valueSubType == BottomType.BOTTOM || valueSubType instanceof ClosedReferenceType referenceSubType && (!referenceSubType.nullable || this.nullable) && + return valueSubType instanceof ClosedReferenceType referenceSubType && (!referenceSubType.nullable || this.nullable) && this.closedHeapType.isSupertypeOf(referenceSubType.closedHeapType); } @Override public boolean isSubtypeOf(ClosedValueType valueSuperType) { - return valueSuperType == TopType.TOP || valueSuperType instanceof ClosedReferenceType referencedSuperType && (!this.nullable || referencedSuperType.nullable) && - this.closedHeapType.isSubtypeOf(referencedSuperType.closedHeapType); + return valueSuperType instanceof ClosedReferenceType referencedSuperType && (!this.nullable || referencedSuperType.nullable) && + this.closedHeapType.isSubtypeOf(referencedSuperType.closedHeapType); } @Override @@ -268,6 +265,11 @@ public Kind kind() { return Kind.Reference; } + @Override + public boolean equals(Object obj) { + return obj instanceof ClosedReferenceType that && this.nullable == that.nullable && this.closedHeapType.equals(that.closedHeapType); + } + @Override public int hashCode() { return Boolean.hashCode(nullable) ^ closedHeapType.hashCode(); @@ -319,6 +321,11 @@ public Kind kind() { return Kind.Abstract; } + @Override + public boolean equals(Object that) { + return this == that; + } + @Override public int hashCode() { return value; @@ -410,72 +417,13 @@ public Kind kind() { } @Override - public int hashCode() { - return Arrays.hashCode(paramTypes) ^ Arrays.hashCode(resultTypes); - } - } - - public static final class BottomType extends ClosedValueType { - public static final BottomType BOTTOM = new BottomType(); - - private BottomType() { - } - - @Override - public boolean isSupertypeOf(ClosedValueType valueSubType) { - return valueSubType == BOTTOM; - } - - @Override - public boolean isSubtypeOf(ClosedValueType valueSuperType) { - return true; - } - - @Override - public boolean matchesValue(Object value) { - return false; - } - - @Override - public Kind kind() { - return Kind.Bottom; + public boolean equals(Object obj) { + return obj instanceof ClosedFunctionType that && Arrays.equals(this.paramTypes, that.paramTypes) && Arrays.equals(this.resultTypes, that.resultTypes); } @Override public int hashCode() { - return Integer.MIN_VALUE; - } - } - - public static final class TopType extends ClosedValueType { - public static final TopType TOP = new TopType(); - - private TopType() { - } - - @Override - public boolean isSupertypeOf(ClosedValueType valueSubType) { - return true; - } - - @Override - public boolean isSubtypeOf(ClosedValueType valueSuperType) { - return valueSuperType == TOP; - } - - @Override - public boolean matchesValue(Object value) { - return false; - } - - @Override - public Kind kind() { - return Kind.Top; - } - - @Override - public int hashCode() { - return Integer.MAX_VALUE; + return Arrays.hashCode(paramTypes) ^ Arrays.hashCode(resultTypes); } } @@ -543,7 +491,13 @@ public record TagInfo(byte attribute, int typeIndex) { */ @CompilationFinal(dimensions = 1) private int[] typeOffsets; - @CompilationFinal(dimensions = 1) private ClosedValueType[] closedTypes; + /** + * Stores the closed forms of all the types defined in this module. Closed forms replace type + * indices with the definitions of the referenced types, resulting in a tree-like data + * structure. + */ + @CompilationFinal(dimensions = 1) private ClosedHeapType[] closedTypes; + /** * Stores the type equivalence class. *

@@ -760,7 +714,7 @@ public record TagInfo(byte attribute, int typeIndex) { CompilerAsserts.neverPartOfCompilation(); this.typeData = new int[INITIAL_DATA_SIZE]; this.typeOffsets = new int[INITIAL_TYPE_SIZE]; - this.closedTypes = new ClosedValueType[INITIAL_TYPE_SIZE]; + this.closedTypes = new ClosedHeapType[INITIAL_TYPE_SIZE]; this.typeEquivalenceClasses = new int[INITIAL_TYPE_SIZE]; this.typeDataSize = 0; this.typeCount = 0; @@ -818,18 +772,6 @@ public void checkFunctionIndex(int funcIndex) { assertUnsignedIntLess(funcIndex, numFunctions, Failure.UNKNOWN_FUNCTION); } - private static int[] reallocate(int[] array, int currentSize, int newLength) { - int[] newArray = new int[newLength]; - System.arraycopy(array, 0, newArray, 0, currentSize); - return newArray; - } - - private static T[] reallocate(T[] array, int currentSize, int newLength) { - T[] newArray = (T[]) Array.newInstance(array.getClass().getComponentType(), newLength); - System.arraycopy(array, 0, newArray, 0, currentSize); - return newArray; - } - /** * Ensure that the {@link #typeData} array has enough space to store {@code index}. If there is * no enough space, then a reallocation of the array takes place, doubling its capacity. @@ -840,7 +782,7 @@ private static T[] reallocate(T[] array, int currentSize, int newLength) { private void ensureTypeDataCapacity(int index) { if (typeData.length <= index) { int newLength = Math.max(Integer.highestOneBit(index) << 1, 2 * typeData.length); - typeData = reallocate(typeData, typeDataSize, newLength); + typeData = Arrays.copyOf(typeData, newLength); } } @@ -855,9 +797,9 @@ private void ensureTypeDataCapacity(int index) { private void ensureTypeCapacity(int index) { if (typeOffsets.length <= index) { int newLength = Math.max(Integer.highestOneBit(index) << 1, 2 * typeOffsets.length); - typeOffsets = reallocate(typeOffsets, typeCount, newLength); - closedTypes = reallocate(closedTypes, typeCount, newLength); - typeEquivalenceClasses = reallocate(typeEquivalenceClasses, typeCount, newLength); + typeOffsets = Arrays.copyOf(typeOffsets, newLength); + closedTypes = Arrays.copyOf(closedTypes, newLength); + typeEquivalenceClasses = Arrays.copyOf(typeEquivalenceClasses, newLength); } } @@ -905,7 +847,15 @@ void registerFunctionTypeResultType(int funcTypeIdx, int resultIdx, int type) { } void finishFunctionType(int funcTypeIdx) { - closedTypes[funcTypeIdx] = makeClosedType(funcTypeIdx); + ClosedValueType[] paramTypes = new ClosedValueType[functionTypeParamCount(funcTypeIdx)]; + for (int i = 0; i < paramTypes.length; i++) { + paramTypes[i] = closedTypeOf(functionTypeParamTypeAt(funcTypeIdx, i)); + } + ClosedValueType[] resultTypes = new ClosedValueType[functionTypeResultCount(funcTypeIdx)]; + for (int i = 0; i < resultTypes.length; i++) { + resultTypes[i] = closedTypeOf(functionTypeResultTypeAt(funcTypeIdx, i)); + } + closedTypes[funcTypeIdx] = new ClosedFunctionType(paramTypes, resultTypes); } public int equivalenceClass(int typeIndex) { @@ -923,7 +873,7 @@ void setEquivalenceClass(int index, int eqClass) { private void ensureFunctionsCapacity(int index) { if (functions.length <= index) { int newLength = Math.max(Integer.highestOneBit(index) << 1, 2 * functions.length); - functions = reallocate(functions, numFunctions, newLength); + functions = Arrays.copyOf(functions, newLength); } } @@ -1030,49 +980,49 @@ int typeCount() { return typeCount; } + /** + * Convenience function for calling {@link #closedTypeAt(int)} when the defined type at index + * {@code typeIndex} is known to be a function type. + * + * @see #closedTypeAt(int) + */ public ClosedFunctionType closedFunctionTypeAt(int typeIndex) { - ClosedValueType[] paramTypes = new ClosedValueType[functionTypeParamCount(typeIndex)]; - for (int i = 0; i < paramTypes.length; i++) { - paramTypes[i] = makeClosedType(functionTypeParamTypeAt(typeIndex, i)); - } - ClosedValueType[] resultTypes = new ClosedValueType[functionTypeResultCount(typeIndex)]; - for (int i = 0; i < resultTypes.length; i++) { - resultTypes[i] = makeClosedType(functionTypeResultTypeAt(typeIndex, i)); - } - return new ClosedFunctionType(paramTypes, resultTypes); + return (ClosedFunctionType) closedTypeAt(typeIndex); } - public ClosedValueType closedTypeAt(int typeIndex) { + /** + * Fetches the closed form of a type defined in this module at index {@code typeIndex}. + * + * @param typeIndex index of a type defined in this module + */ + public ClosedHeapType closedTypeAt(int typeIndex) { return closedTypes[typeIndex]; } - public static ClosedValueType closedTypeOfPredefined(int type) { - return switch (type) { - case WasmType.I32_TYPE -> NumberType.I32; - case WasmType.I64_TYPE -> NumberType.I64; - case WasmType.F32_TYPE -> NumberType.F32; - case WasmType.F64_TYPE -> NumberType.F64; - case WasmType.V128_TYPE -> VectorType.V128; - default -> { - if (WasmType.isBottomType(type)) { - yield BottomType.BOTTOM; - } - assert WasmType.isReferenceType(type); - boolean nullable = WasmType.isNullable(type); - yield switch (WasmType.getAbstractHeapType(type)) { - case WasmType.FUNC_HEAPTYPE -> nullable ? ClosedReferenceType.FUNCREF : ClosedReferenceType.NONNULL_FUNCREF; - case WasmType.EXTERN_HEAPTYPE -> nullable ? ClosedReferenceType.EXTERNREF : ClosedReferenceType.NONNULL_EXTERNREF; - case WasmType.EXN_HEAPTYPE -> nullable ? ClosedReferenceType.EXNREF : ClosedReferenceType.NONNULL_EXNREF; - default -> { - assert WasmType.isConcreteReferenceType(type); - throw new IllegalArgumentException(); - } - }; - } - }; + /** + * A convenient way of calling {@link #closedTypeOf(int, SymbolTable)} when a + * {@link SymbolTable} is present. + * + * @see #closedTypeOf(int, SymbolTable) + */ + public ClosedValueType closedTypeOf(int type) { + return SymbolTable.closedTypeOf(type, this); } - public ClosedValueType makeClosedType(int type) { + /** + * Maps a type encoded as an {@code int} (as per {@link WasmType}) into its closed form, + * represented as a {@link ClosedValueType}. Any type indices in the type are resolved using the + * provided symbol table. + *

+ * It is legal to call this function with a null {@code symbolTable}. This is used in cases + * where we need to map a predefined value type to the closed type data representation (i.e. we + * know the type is already closed anyway and so it does not contain any type indices). + *

+ * + * @param type the {@code int}-encoded Wasm type to be expanded + * @param symbolTable used for lookup of type definitions when expanding type indices + */ + public static ClosedValueType closedTypeOf(int type, SymbolTable symbolTable) { return switch (type) { case WasmType.I32_TYPE -> NumberType.I32; case WasmType.I64_TYPE -> NumberType.I64; @@ -1080,9 +1030,6 @@ public ClosedValueType makeClosedType(int type) { case WasmType.F64_TYPE -> NumberType.F64; case WasmType.V128_TYPE -> VectorType.V128; default -> { - if (WasmType.isBottomType(type)) { - yield BottomType.BOTTOM; - } assert WasmType.isReferenceType(type); boolean nullable = WasmType.isNullable(type); yield switch (WasmType.getAbstractHeapType(type)) { @@ -1091,15 +1038,38 @@ yield switch (WasmType.getAbstractHeapType(type)) { case WasmType.EXN_HEAPTYPE -> nullable ? ClosedReferenceType.EXNREF : ClosedReferenceType.NONNULL_EXNREF; default -> { assert WasmType.isConcreteReferenceType(type); - yield new ClosedReferenceType(nullable, closedFunctionTypeAt(WasmType.getTypeIndex(type))); + assert symbolTable != null; + int typeIndex = WasmType.getTypeIndex(type); + ClosedHeapType heapType = symbolTable.closedTypeAt(typeIndex); + yield new ClosedReferenceType(nullable, heapType); } }; } }; } - public boolean matches(int expectedType, int actualType) { - return makeClosedType(expectedType).isSupertypeOf(makeClosedType(actualType)); + /** + * Checks whether the type {@code actualType} matches the type {@code expectedType}. This is the + * case when {@code actualType} is a subtype of {@code expectedType}. + */ + public boolean matchesType(int expectedType, int actualType) { + switch (expectedType) { + case WasmType.BOT -> { + return false; + } + case WasmType.TOP -> { + return true; + } + } + switch (actualType) { + case WasmType.BOT -> { + return true; + } + case WasmType.TOP -> { + return false; + } + } + return closedTypeOf(expectedType).isSupertypeOf(closedTypeOf(actualType)); } public void importSymbol(ImportDescriptor descriptor) { @@ -1296,7 +1266,7 @@ public int globalValueType(int index) { } public ClosedValueType globalClosedValueType(int index) { - return makeClosedType(globalTypes[index]); + return closedTypeOf(globalTypes[index]); } private byte globalFlags(int index) { @@ -1745,7 +1715,7 @@ public void checkElemIndex(int elemIndex) { } public void checkElemType(int elemIndex, int expectedType) { - Assert.assertTrue(matches(expectedType, (int) elemInstances[elemIndex]), Failure.TYPE_MISMATCH); + Assert.assertTrue(matchesType(expectedType, (int) elemInstances[elemIndex]), Failure.TYPE_MISMATCH); } private void ensureElemInstanceCapacity(int index) { diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmTable.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmTable.java index afde7747e9fe..85eafc7d5d83 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmTable.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmTable.java @@ -68,7 +68,7 @@ public final class WasmTable extends EmbedderDataHolder implements TruffleObject private final int elemType; /** - * For resolving {@link #elemType} in {@link #closedValueType()}. Can be {@code null} for tables + * For resolving {@link #elemType} in {@link #closedElemType()}. Can be {@code null} for tables * allocated from JS. */ private final SymbolTable symbolTable; @@ -170,8 +170,11 @@ public int elemType() { return elemType; } - public SymbolTable.ClosedValueType closedValueType() { - return symbolTable.makeClosedType(elemType); + /** + * The closed form of the type of the elements in the table. + */ + public SymbolTable.ClosedValueType closedElemType() { + return SymbolTable.closedTypeOf(elemType, symbolTable); } /** diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmType.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmType.java index 682e148de552..d31688f20529 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmType.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmType.java @@ -198,27 +198,22 @@ yield switch (WasmType.getAbstractHeapType(valueType)) { } public static boolean isNumberType(int type) { - return type == I32_TYPE || type == I64_TYPE || type == F32_TYPE || type == F64_TYPE || isBottomType(type); + return type == I32_TYPE || type == I64_TYPE || type == F32_TYPE || type == F64_TYPE || type == BOT; } public static boolean isVectorType(int type) { - return type == V128_TYPE || isBottomType(type); + return type == V128_TYPE || type == BOT; } public static boolean isReferenceType(int type) { - return isConcreteReferenceType(type) || withNullable(true, type) == FUNC_HEAPTYPE || withNullable(true, type) == EXTERN_HEAPTYPE || withNullable(true, type) == EXN_HEAPTYPE || - isBottomType(type); - } - - public static boolean isBottomType(int type) { - return withNullable(true, type) == BOT; + return isConcreteReferenceType(type) || withNullable(true, type) == FUNC_HEAPTYPE || withNullable(true, type) == EXTERN_HEAPTYPE || withNullable(true, type) == EXN_HEAPTYPE || type == BOT; } /** * Indicates whether this is a user-defined reference type. */ public static boolean isConcreteReferenceType(int type) { - return type >= 0 || isBottomType(type); + return type >= 0 || type == BOT; } /** @@ -258,6 +253,9 @@ public static boolean isNullable(int type) { * type index) */ public static int withNullable(boolean nullable, int type) { + if (type == BOT) { + return BOT; + } return nullable ? type | TYPE_NULLABLE_MASK : type & ~TYPE_NULLABLE_MASK; } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ExecuteHostFunctionNode.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ExecuteHostFunctionNode.java index 67e40d867ad3..cb6277d78f4e 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ExecuteHostFunctionNode.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ExecuteHostFunctionNode.java @@ -128,7 +128,7 @@ public Object execute(VirtualFrame frame) { */ private Object convertResult(Object result, int resultType) throws UnsupportedMessageException { CompilerAsserts.partialEvaluationConstant(resultType); - SymbolTable.ClosedValueType closedResultType = module.closedTypeAt(resultType); + SymbolTable.ClosedValueType closedResultType = module.closedTypeOf(resultType); CompilerAsserts.partialEvaluationConstant(closedResultType); return switch (resultType) { case WasmType.I32_TYPE -> asInt(result); @@ -166,7 +166,7 @@ private void pushMultiValueResult(Object result, int resultCount) { for (int i = 0; i < resultCount; i++) { int resultType = module.symbolTable().functionTypeResultTypeAt(functionTypeIndex, i); CompilerAsserts.partialEvaluationConstant(resultType); - SymbolTable.ClosedValueType closedResultType = module.closedTypeAt(resultType); + SymbolTable.ClosedValueType closedResultType = module.closedTypeOf(resultType); CompilerAsserts.partialEvaluationConstant(closedResultType); Object value = arrayInterop.readArrayElement(result, i); switch (resultType) { @@ -176,7 +176,7 @@ private void pushMultiValueResult(Object result, int resultCount) { case WasmType.F64_TYPE -> primitiveMultiValueStack[i] = Double.doubleToRawLongBits(asDouble(value)); default -> { assert WasmType.isVectorType(resultType) || WasmType.isReferenceType(resultType); - if (!closedResultType.matchesValue(result)) { + if (!closedResultType.matchesValue(value)) { errorBranch.enter(); throw WasmException.create(Failure.INVALID_TYPE_IN_MULTI_VALUE); } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/InteropCallAdapterNode.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/InteropCallAdapterNode.java index 2719b169c94b..c5531fd83fe7 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/InteropCallAdapterNode.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/InteropCallAdapterNode.java @@ -183,7 +183,6 @@ yield switch (numberType.value()) { objectMultiValueStack[i] = null; yield obj; } - case Bottom, Top -> throw CompilerDirectives.shouldNotReachHere(); }; } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ValueType.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ValueType.java index e1f3a499e3d4..8bccec71a4ca 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ValueType.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/ValueType.java @@ -133,8 +133,6 @@ yield switch (abstractHeapType.value()) { case Function -> anyfunc; }; } - case Bottom -> throw WasmException.create(Failure.UNSPECIFIED_INTERNAL, null, "Unknown value type: bottom"); - case Top -> throw WasmException.create(Failure.UNSPECIFIED_INTERNAL, null, "Unknown value type: top"); }; } } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/WebAssembly.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/WebAssembly.java index 9bca8ee57e2f..6db72e510860 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/WebAssembly.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/WebAssembly.java @@ -509,7 +509,7 @@ private Object tableWrite(Object[] args) { } public Object tableWrite(WasmTable table, int index, Object element) { - if (!table.closedValueType().matchesValue(element)) { + if (!table.closedElemType().matchesValue(element)) { throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Invalid table element"); } @@ -812,7 +812,7 @@ public Object globalRead(WasmGlobal global) { if (!refTypes) { throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Invalid value type. Reference types are not enabled."); } - if (SymbolTable.closedTypeOfPredefined(WasmType.EXNREF_TYPE).isSupertypeOf(global.getClosedValueType())) { + if (SymbolTable.closedTypeOf(WasmType.EXNREF_TYPE, null).isSupertypeOf(global.getClosedValueType())) { throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, WasmJsApiException.EXNREF_VALUE_ACCESS); } yield global.loadAsReference(); @@ -846,7 +846,7 @@ public Object globalWrite(WasmGlobal global, Object value) { if (!refTypes) { throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Invalid value type. Reference types are not enabled."); } - if (SymbolTable.closedTypeOfPredefined(WasmType.EXNREF_TYPE).isSupertypeOf(global.getClosedValueType())) { + if (SymbolTable.closedTypeOf(WasmType.EXNREF_TYPE, null).isSupertypeOf(global.getClosedValueType())) { throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, WasmJsApiException.EXNREF_VALUE_ACCESS); } global.storeReference(value); diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/globals/WasmGlobal.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/globals/WasmGlobal.java index eaaba8a1369b..8cf40c03638d 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/globals/WasmGlobal.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/globals/WasmGlobal.java @@ -110,21 +110,7 @@ public int getType() { } public SymbolTable.ClosedValueType getClosedValueType() { - if (symbolTable != null) { - return symbolTable.makeClosedType(getType()); - } else { - // Global was created by WebAssembly#global_alloc - return switch (ValueType.fromValue(getType())) { - case i32 -> SymbolTable.NumberType.I32; - case i64 -> SymbolTable.NumberType.I64; - case f32 -> SymbolTable.NumberType.F32; - case f64 -> SymbolTable.NumberType.F64; - case v128 -> SymbolTable.VectorType.V128; - case anyfunc -> SymbolTable.ClosedReferenceType.FUNCREF; - case externref -> SymbolTable.ClosedReferenceType.EXTERNREF; - case exnref -> SymbolTable.ClosedReferenceType.EXNREF; - }; - } + return SymbolTable.closedTypeOf(getType(), symbolTable); } public boolean isMutable() { diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java index 4f7b3a9d537e..a527aeab6df0 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java @@ -592,7 +592,7 @@ public Object executeBodyFromOffset(WasmInstance instance, VirtualFrame frame, i // equivalence classes. If they are not equivalent, we run the full subtype // matching procedure. if (expectedTypeEquivalenceClass != function.typeEquivalenceClass() && - !symtab.closedTypeAt(expectedFunctionTypeIndex).isSupertypeOf(new SymbolTable.ClosedReferenceType(false, function.closedType()))) { + !symtab.closedTypeAt(expectedFunctionTypeIndex).isSupertypeOf(function.closedType())) { enterErrorBranch(); failFunctionTypeCheck(function, expectedFunctionTypeIndex); } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/IfFrame.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/IfFrame.java index b061ba167999..1fef69c27988 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/IfFrame.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/IfFrame.java @@ -94,7 +94,7 @@ void exit(RuntimeBytecodeGen bytecode) { } if (!isUnreachable()) { for (int i = 0; i < resultTypes().length; i++) { - if (!getSymbolTable().matches(resultTypes()[i], paramTypes()[i])) { + if (!getSymbolTable().matchesType(resultTypes()[i], paramTypes()[i])) { throw WasmException.create(Failure.TYPE_MISMATCH, "Expected else branch. If with incompatible param and result types requires else branch."); } } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ParserState.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ParserState.java index d659e87e4187..63d783cf62c5 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ParserState.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ParserState.java @@ -151,7 +151,7 @@ private boolean isTypeMismatch(int[] expectedTypes, int[] actualTypes) { return false; } for (int i = 0; i < expectedTypes.length; i++) { - if (!symbolTable.matches(expectedTypes[i], actualTypes[i])) { + if (!symbolTable.matchesType(expectedTypes[i], actualTypes[i])) { return true; } } @@ -198,7 +198,7 @@ public int pop() { */ public int popChecked(int expectedValueType) { final int actualValueType = popInternal(expectedValueType); - if (!WasmType.isBottomType(actualValueType) && !WasmType.isBottomType(expectedValueType) && !symbolTable.matches(expectedValueType, actualValueType)) { + if (!symbolTable.matchesType(expectedValueType, actualValueType)) { throw ValidationErrors.createTypeMismatch(expectedValueType, actualValueType); } return actualValueType; @@ -434,7 +434,7 @@ public void addBranchOnNonNull(int branchLabel, int referenceType) { if (labelTypes.length < 1) { throw ValidationErrors.createLabelTypesMismatch(labelTypes, new int[]{referenceType}); } - if (!symbolTable.matches(labelTypes[labelTypes.length - 1], referenceType)) { + if (!symbolTable.matchesType(labelTypes[labelTypes.length - 1], referenceType)) { throw ValidationErrors.createTypeMismatch(labelTypes[labelTypes.length - 1], referenceType); } for (int i = labelTypes.length - 2; i >= 0; i--) { diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ValidationErrors.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ValidationErrors.java index 9226b0c7c750..26f9e8d926d5 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ValidationErrors.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/validation/ValidationErrors.java @@ -95,6 +95,11 @@ public static WasmException createMissingLabel(int expected, int max) { return WasmException.format(Failure.UNKNOWN_LABEL, "Unknown branch label %d (max %d).", expected, max); } + @TruffleBoundary + public static WasmException createMissingFunctionType(int expected) { + return WasmException.format(Failure.UNKNOWN_TYPE, "Function type variable %d out of range.", expected); + } + @TruffleBoundary public static WasmException createMissingFunctionType(int expected, int max) { return WasmException.format(Failure.UNKNOWN_TYPE, "Function type variable %d out of range. (max %d)", expected, max); From e16aea98142257088c439496cc4efb97ce1a33e0 Mon Sep 17 00:00:00 2001 From: Jirka Marsik Date: Mon, 20 Oct 2025 01:59:33 +0200 Subject: [PATCH 31/40] Use hexadecimal in opcode definitions consistently --- .../graalvm/wasm/constants/Instructions.java | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/Instructions.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/Instructions.java index 2bf625b5fd61..183245e9cdda 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/Instructions.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/Instructions.java @@ -266,16 +266,16 @@ public final class Instructions { public static final int REF_IS_NULL = 0xD1; public static final int REF_FUNC = 0xD2; - public static final int MEMORY_INIT = 8; - public static final int DATA_DROP = 9; - public static final int MEMORY_COPY = 10; - public static final int MEMORY_FILL = 11; - public static final int TABLE_INIT = 12; - public static final int ELEM_DROP = 13; - public static final int TABLE_COPY = 14; - public static final int TABLE_GROW = 15; - public static final int TABLE_SIZE = 16; - public static final int TABLE_FILL = 17; + public static final int MEMORY_INIT = 0x08; + public static final int DATA_DROP = 0x09; + public static final int MEMORY_COPY = 0x0A; + public static final int MEMORY_FILL = 0x0B; + public static final int TABLE_INIT = 0x0C; + public static final int ELEM_DROP = 0x0D; + public static final int TABLE_COPY = 0x0E; + public static final int TABLE_GROW = 0x0F; + public static final int TABLE_SIZE = 0x10; + public static final int TABLE_FILL = 0x11; public static final int CALL_REF = 0x14; public static final int REF_AS_NON_NULL = 0xD4; From 20b8d833d8675cc2c42171939c198be076b723d1 Mon Sep 17 00:00:00 2001 From: Jirka Marsik Date: Mon, 20 Oct 2025 15:12:17 +0200 Subject: [PATCH 32/40] Stabilize memory footprint benchmark --- .../graalvm/wasm/benchmark/MemoryFootprintBenchmarkRunner.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wasm/src/org.graalvm.wasm.benchmark/src/org/graalvm/wasm/benchmark/MemoryFootprintBenchmarkRunner.java b/wasm/src/org.graalvm.wasm.benchmark/src/org/graalvm/wasm/benchmark/MemoryFootprintBenchmarkRunner.java index 1e22bf48220d..d175735884f0 100644 --- a/wasm/src/org.graalvm.wasm.benchmark/src/org/graalvm/wasm/benchmark/MemoryFootprintBenchmarkRunner.java +++ b/wasm/src/org.graalvm.wasm.benchmark/src/org/graalvm/wasm/benchmark/MemoryFootprintBenchmarkRunner.java @@ -169,7 +169,7 @@ static double getHeapSize() { static void sleep() { try { - Thread.sleep(2000); + Thread.sleep(5000); } catch (InterruptedException e) { Thread.currentThread().interrupt(); } From 1c9f8f91169aa0685c45d4111316a667876440ef Mon Sep 17 00:00:00 2001 From: Jirka Marsik Date: Tue, 21 Oct 2025 12:13:21 +0200 Subject: [PATCH 33/40] Check type equivalence for mutable global imports, clean up global imports --- .../src/org/graalvm/wasm/Linker.java | 40 ++++++++----------- .../src/org/graalvm/wasm/SymbolTable.java | 15 +++++-- .../src/org/graalvm/wasm/api/WebAssembly.java | 6 +-- .../org/graalvm/wasm/globals/WasmGlobal.java | 11 ++--- 4 files changed, 36 insertions(+), 36 deletions(-) diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java index 8b4064102567..4125e9237625 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java @@ -327,22 +327,14 @@ private static void checkFailures(ArrayList failures) { } } - void resolveGlobalImport(WasmStore store, WasmInstance instance, ImportDescriptor importDescriptor, int globalIndex, int valueType, byte mutability, - ImportValueSupplier imports) { + void resolveGlobalImport(WasmStore store, WasmInstance instance, ImportDescriptor importDescriptor, int globalIndex, int valueType, byte mutability, ImportValueSupplier imports) { instance.globals().setInitialized(globalIndex, false); final String importedGlobalName = importDescriptor.memberName(); final String importedModuleName = importDescriptor.moduleName(); final Runnable resolveAction = () -> { assert instance.module().globalImported(globalIndex) && globalIndex == importDescriptor.targetIndex() : importDescriptor; WasmGlobal externalGlobal = lookupImportObject(instance, importDescriptor, imports, WasmGlobal.class); - final int exportedValueType; - final SymbolTable.ClosedValueType exportedClosedValueType; - final byte exportedMutability; - if (externalGlobal != null) { - exportedValueType = externalGlobal.getType(); - exportedClosedValueType = externalGlobal.getClosedValueType(); - exportedMutability = externalGlobal.getMutability(); - } else { + if (externalGlobal == null) { final WasmInstance importedInstance = store.lookupModuleInstance(importedModuleName); if (importedInstance == null) { throw WasmException.create(Failure.UNKNOWN_IMPORT, "Module '" + importedModuleName + "', referenced in the import of global variable '" + @@ -356,21 +348,22 @@ void resolveGlobalImport(WasmStore store, WasmInstance instance, ImportDescripto "', was not exported in the module '" + importedModuleName + "'."); } - exportedValueType = importedInstance.symbolTable().globalValueType(exportedGlobalIndex); - exportedClosedValueType = importedInstance.symbolTable().globalClosedValueType(exportedGlobalIndex); - exportedMutability = importedInstance.symbolTable().globalMutability(exportedGlobalIndex); - externalGlobal = importedInstance.externalGlobal(exportedGlobalIndex); } - if (!instance.symbolTable().closedTypeOf(valueType).isSupertypeOf(exportedClosedValueType)) { + SymbolTable.ClosedValueType importType = instance.symbolTable().closedTypeOf(valueType); + SymbolTable.ClosedValueType exportType = externalGlobal.getClosedType(); + if (mutability != externalGlobal.getMutability()) { throw WasmException.create(Failure.INCOMPATIBLE_IMPORT_TYPE, "Global variable '" + importedGlobalName + "' is imported into module '" + instance.name() + - "' with the type " + WasmType.toString(valueType) + ", " + - "but it was exported in the module '" + importedModuleName + "' with the type " + WasmType.toString(exportedValueType) + "."); + "' with the modifier " + GlobalModifier.asString(mutability) + ", " + + "but it was exported in the module '" + importedModuleName + "' with the modifier " + GlobalModifier.asString(externalGlobal.getMutability()) + "."); } - if (exportedMutability != mutability) { + // matching for mutable globals does not work by subtyping, but requires equivalent + // types + if (!(externalGlobal.isMutable() ? importType.equals(exportType) : importType.isSupertypeOf(exportType))) { throw WasmException.create(Failure.INCOMPATIBLE_IMPORT_TYPE, "Global variable '" + importedGlobalName + "' is imported into module '" + instance.name() + - "' with the modifier " + GlobalModifier.asString(mutability) + ", " + - "but it was exported in the module '" + importedModuleName + "' with the modifier " + GlobalModifier.asString(exportedMutability) + "."); + "' with the type " + GlobalModifier.asString(mutability) + " " + WasmType.toString(valueType) + ", " + + "but it was exported in the module '" + importedModuleName + "' with the type " + GlobalModifier.asString(externalGlobal.getMutability()) + " " + + WasmType.toString(externalGlobal.getType()) + "."); } instance.setExternalGlobal(globalIndex, externalGlobal); instance.globals().setInitialized(globalIndex, true); @@ -391,7 +384,7 @@ private static void initializeGlobal(WasmInstance instance, int globalIndex, Obj assert !instance.globals().isInitialized(globalIndex) : globalIndex; SymbolTable symbolTable = instance.symbolTable(); if (symbolTable.globalExternal(globalIndex)) { - var global = new WasmGlobal(symbolTable.globalValueType(globalIndex), symbolTable.isGlobalMutable(globalIndex), instance.symbolTable(), initValue); + var global = new WasmGlobal(globalIndex, symbolTable, initValue); instance.setExternalGlobal(globalIndex, global); } else { instance.globals().store(symbolTable.globalValueType(globalIndex), symbolTable.globalAddress(globalIndex), initValue); @@ -568,9 +561,8 @@ void resolveTagImport(WasmStore store, WasmInstance instance, ImportDescriptor i } importedTag = importedInstance.tag(exportedTagIndex); } - // matching for tag types does not work by subtyping, but requires equivalent types, - // A <= B and B <= A - Assert.assertTrue(type.isSupertypeOf(importedTag.type()) && importedTag.type().isSupertypeOf(type), Failure.INCOMPATIBLE_IMPORT_TYPE); + // matching for tag types does not work by subtyping, but requires equivalent types + Assert.assertTrue(type.equals(importedTag.type()), Failure.INCOMPATIBLE_IMPORT_TYPE); instance.setTag(tagIndex, importedTag); }; resolutionDag.resolveLater(new ImportTagSym(instance.name(), importDescriptor, tagIndex), new Sym[]{new ExportTagSym(importedModuleName, importedTagName)}, resolveAction); diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java index dbc060beeb28..a70b20b363c9 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java @@ -89,6 +89,17 @@ public abstract class SymbolTable { public static final int NO_EQUIVALENCE_CLASS = 0; public static final int FIRST_EQUIVALENCE_CLASS = NO_EQUIVALENCE_CLASS + 1; + /** + * Represents a WebAssembly value type in its closed form, with all type indices replaced with + * their definitions. You can query the subtyping relation on types using the predicates + * {@link #isSupertypeOf(ClosedValueType)} and {@link #isSubtypeOf(ClosedValueType)}, both of + * which are written to be PE-friendly provided the receiver type is a PE constant. + *

+ * If you need to check whether two types are equivalent, instead of checking + * {@code A.isSupertypeOf(B) && A.isSubtypeOf(B)}, you can use {@code A.equals(B)}, since, in + * the WebAssembly type system, type equivalence corresponds to structural equality. + *

+ */ public abstract static sealed class ClosedValueType { // This is a workaround until we can use pattern matching in JDK 21+. public enum Kind { @@ -1265,10 +1276,6 @@ public int globalValueType(int index) { return globalTypes[index]; } - public ClosedValueType globalClosedValueType(int index) { - return closedTypeOf(globalTypes[index]); - } - private byte globalFlags(int index) { return globalFlags[index]; } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/WebAssembly.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/WebAssembly.java index 6db72e510860..b7b4b49fccab 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/WebAssembly.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/WebAssembly.java @@ -812,7 +812,7 @@ public Object globalRead(WasmGlobal global) { if (!refTypes) { throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, "Invalid value type. Reference types are not enabled."); } - if (SymbolTable.closedTypeOf(WasmType.EXNREF_TYPE, null).isSupertypeOf(global.getClosedValueType())) { + if (SymbolTable.closedTypeOf(WasmType.EXNREF_TYPE, null).isSupertypeOf(global.getClosedType())) { throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, WasmJsApiException.EXNREF_VALUE_ACCESS); } yield global.loadAsReference(); @@ -832,7 +832,7 @@ public Object globalWrite(WasmGlobal global, Object value) { if (!global.isMutable()) { throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Global is not mutable."); } - if (!global.getClosedValueType().matchesValue(value)) { + if (!global.getClosedType().matchesValue(value)) { throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Global type %s, value: %s", ValueType.fromValue(global.getType()), value); } switch (global.getType()) { @@ -846,7 +846,7 @@ public Object globalWrite(WasmGlobal global, Object value) { if (!refTypes) { throw WasmJsApiException.format(WasmJsApiException.Kind.TypeError, "Invalid value type. Reference types are not enabled."); } - if (SymbolTable.closedTypeOf(WasmType.EXNREF_TYPE, null).isSupertypeOf(global.getClosedValueType())) { + if (SymbolTable.closedTypeOf(WasmType.EXNREF_TYPE, null).isSupertypeOf(global.getClosedType())) { throw new WasmJsApiException(WasmJsApiException.Kind.TypeError, WasmJsApiException.EXNREF_VALUE_ACCESS); } global.storeReference(value); diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/globals/WasmGlobal.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/globals/WasmGlobal.java index 8cf40c03638d..a3ce0b467feb 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/globals/WasmGlobal.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/globals/WasmGlobal.java @@ -93,8 +93,9 @@ public static WasmGlobal allocRef(ValueType valueType, boolean mutable, Object v return result; } - public WasmGlobal(int type, boolean mutable, SymbolTable symbolTable, Object value) { - this(type, mutable, symbolTable); + public WasmGlobal(int globalIndex, SymbolTable symbolTable, Object value) { + this(symbolTable.globalValueType(globalIndex), symbolTable.isGlobalMutable(globalIndex), symbolTable); + assert symbolTable.globalExternal(globalIndex); this.globalValue = switch (type) { case WasmType.I32_TYPE -> (int) value; case WasmType.I64_TYPE -> (long) value; @@ -109,7 +110,7 @@ public int getType() { return type; } - public SymbolTable.ClosedValueType getClosedValueType() { + public SymbolTable.ClosedValueType getClosedType() { return SymbolTable.closedTypeOf(getType(), symbolTable); } @@ -225,14 +226,14 @@ void writeMember(String member, Object value, case WasmType.F32_TYPE -> storeInt(Float.floatToRawIntBits(valueLibrary.asFloat(value))); case WasmType.F64_TYPE -> storeLong(Double.doubleToRawLongBits(valueLibrary.asDouble(value))); case WasmType.V128_TYPE -> { - if (!getClosedValueType().matchesValue(value)) { + if (!getClosedType().matchesValue(value)) { throw UnsupportedMessageException.create(); } storeVector128((Vector128) value); } default -> { assert WasmType.isReferenceType(type); - if (!getClosedValueType().matchesValue(value)) { + if (!getClosedType().matchesValue(value)) { throw UnsupportedMessageException.create(); } storeReference(value); From d5ca22305fc51b68420ad85b7dc8a07062818cd2 Mon Sep 17 00:00:00 2001 From: Jirka Marsik Date: Tue, 21 Oct 2025 12:40:07 +0200 Subject: [PATCH 34/40] Add toString impl for ClosedValueType hierarchy These string values show up in root node names and ease debugging. --- .../src/org/graalvm/wasm/SymbolTable.java | 61 +++++++++++++++++++ .../wasm/api/InteropCallAdapterNode.java | 1 - 2 files changed, 61 insertions(+), 1 deletion(-) diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java index a70b20b363c9..38434d1d2204 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java @@ -184,6 +184,17 @@ public boolean equals(Object that) { public int hashCode() { return value; } + + @Override + public String toString() { + return switch (value()) { + case WasmType.I32_TYPE -> "i32"; + case WasmType.I64_TYPE -> "i64"; + case WasmType.F32_TYPE -> "f32"; + case WasmType.F64_TYPE -> "f64"; + default -> throw CompilerDirectives.shouldNotReachHere(); + }; + } } public static final class VectorType extends ClosedValueType { @@ -228,6 +239,11 @@ public boolean equals(Object that) { public int hashCode() { return value; } + + @Override + public String toString() { + return "v128"; + } } public static final class ClosedReferenceType extends ClosedValueType { @@ -285,6 +301,27 @@ public boolean equals(Object obj) { public int hashCode() { return Boolean.hashCode(nullable) ^ closedHeapType.hashCode(); } + + @Override + public String toString() { + CompilerAsserts.neverPartOfCompilation(); + if (this == FUNCREF) { + return "funcref"; + } else if (this == EXTERNREF) { + return "externref"; + } else if (this == EXNREF) { + return "exnref"; + } else { + StringBuilder buf = new StringBuilder(); + buf.append("(ref "); + if (nullable) { + buf.append("null "); + } + buf.append(closedHeapType.toString()); + buf.append(")"); + return buf.toString(); + } + } } public static final class AbstractHeapType extends ClosedHeapType { @@ -341,6 +378,16 @@ public boolean equals(Object that) { public int hashCode() { return value; } + + @Override + public String toString() { + return switch (this.value) { + case WasmType.FUNC_HEAPTYPE -> "func"; + case WasmType.EXTERN_HEAPTYPE -> "extern"; + case WasmType.EXN_HEAPTYPE -> "exn"; + default -> throw CompilerDirectives.shouldNotReachHere(); + }; + } } public static final class ClosedFunctionType extends ClosedHeapType { @@ -436,6 +483,20 @@ public boolean equals(Object obj) { public int hashCode() { return Arrays.hashCode(paramTypes) ^ Arrays.hashCode(resultTypes); } + + @Override + public String toString() { + CompilerAsserts.neverPartOfCompilation(); + String[] paramNames = new String[paramTypes.length]; + for (int i = 0; i < paramTypes.length; i++) { + paramNames[i] = paramTypes[i].toString(); + } + String[] resultNames = new String[resultTypes.length]; + for (int i = 0; i < resultTypes.length; i++) { + resultNames[i] = resultTypes[i].toString(); + } + return "(" + String.join(" ", paramNames) + ")->(" + String.join(" ", resultNames) + ")"; + } } /** diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/InteropCallAdapterNode.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/InteropCallAdapterNode.java index c5531fd83fe7..82aa7f8f9865 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/InteropCallAdapterNode.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/InteropCallAdapterNode.java @@ -186,7 +186,6 @@ yield switch (numberType.value()) { }; } - // TODO: Do we need the 3 overrides below? @Override public String getName() { return "wasm-function-interop:" + functionType; From c278cf7a9074892596f38948647203af28979186 Mon Sep 17 00:00:00 2001 From: Jirka Marsik Date: Tue, 21 Oct 2025 17:44:21 +0200 Subject: [PATCH 35/40] Use type equality when checking type of imported tables --- wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java index 4125e9237625..853089e2980d 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java @@ -873,7 +873,9 @@ void resolveTableImport(WasmStore store, WasmInstance instance, ImportDescriptor // MAX_TABLE_DECLARATION_SIZE, so this condition will pass. assertUnsignedIntLessOrEqual(declaredMinSize, importedTable.minSize(), Failure.INCOMPATIBLE_IMPORT_TYPE); assertUnsignedIntGreaterOrEqual(declaredMaxSize, importedTable.declaredMaxSize(), Failure.INCOMPATIBLE_IMPORT_TYPE); - assertTrue(instance.symbolTable().closedTypeOf(elemType).isSupertypeOf(importedTable.closedElemType()), Failure.INCOMPATIBLE_IMPORT_TYPE); + // when matching element types of imported tables, we need to check for type equivalence + // instead of subtyping, as tables have read/write access + assertTrue(instance.symbolTable().closedTypeOf(elemType).equals(importedTable.closedElemType()), Failure.INCOMPATIBLE_IMPORT_TYPE); instance.setTableAddress(tableIndex, tableAddress); }; final ImportTableSym importTableSym = new ImportTableSym(instance.name(), importDescriptor); From a347f3abecf441aec8ca3fb35b9cade16cc69c34 Mon Sep 17 00:00:00 2001 From: Jirka Marsik Date: Tue, 21 Oct 2025 17:57:03 +0200 Subject: [PATCH 36/40] Allow executing functions from Wasm modules that have failed linking This is called for by the official spec test "linking". We keep the check only for modules eval'ed with EvalReturnsInstance, where it is possible for the embedder to obtain a WasmInstance that will fail linking. Without EvalReturnsInstance, instances must pass linking before they are returned to the embedder. --- .../src/org/graalvm/wasm/nodes/WasmRootNode.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmRootNode.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmRootNode.java index 0d90a6ccf5ef..351c96467a27 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmRootNode.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmRootNode.java @@ -84,7 +84,7 @@ public final void tryInitialize(WasmInstance instance) { // We want to ensure that linking always precedes the running of the WebAssembly code. // This linking should be as late as possible, because a WebAssembly context should // be able to parse multiple modules before the code gets run. - if (!instance.isLinkCompletedFastPath()) { + if (getContext().getContextOptions().evalReturnsInstance() && !instance.isLinkCompletedFastPath()) { nonLinkedProfile.enter(); instance.store().linker().tryLinkFastPath(instance); } From 2a0f44c1a489e1696c800a95180e445740c0ca4d Mon Sep 17 00:00:00 2001 From: Jirka Marsik Date: Wed, 22 Oct 2025 00:19:10 +0200 Subject: [PATCH 37/40] Set correct order of elem and data link actions in Instantiator These actions have side effects and their ordering can matter insofar as one link action can fail and only the effects of those that came before should be observable. --- .../src/org/graalvm/wasm/Linker.java | 4 +- .../org/graalvm/wasm/WasmInstantiator.java | 204 +++++++++--------- 2 files changed, 104 insertions(+), 104 deletions(-) diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java index 853089e2980d..fdd9aeb35378 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java @@ -1357,7 +1357,7 @@ static class ImportTableSym extends Sym { @Override public String toString() { - return String.format("(import memory %s from %s into %s)", importDescriptor.memberName(), importDescriptor.moduleName(), moduleName); + return String.format("(import table %s from %s into %s)", importDescriptor.memberName(), importDescriptor.moduleName(), moduleName); } @Override @@ -1435,7 +1435,7 @@ static class ElemSym extends Sym { @Override public String toString() { - return String.format(Locale.ROOT, "(data %d in %s)", elemSegmentId, moduleName); + return String.format(Locale.ROOT, "(elem %d in %s)", elemSegmentId, moduleName); } @Override diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmInstantiator.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmInstantiator.java index 2bc19df720c5..c9553ae9fc29 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmInstantiator.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmInstantiator.java @@ -229,108 +229,6 @@ static List recreateLinkActions(WasmModule module) { final byte[] bytecode = module.bytecode(); - for (int i = 0; i < module.dataInstanceCount(); i++) { - final int dataIndex = i; - final int dataOffset = module.dataInstanceOffset(dataIndex); - final int encoding = bytecode[dataOffset]; - int effectiveOffset = dataOffset + 1; - - final int dataMode = encoding & BytecodeBitEncoding.DATA_SEG_MODE_VALUE; - final int dataLength; - switch (encoding & BytecodeBitEncoding.DATA_SEG_LENGTH_MASK) { - case BytecodeBitEncoding.DATA_SEG_LENGTH_U8: - dataLength = BinaryStreamParser.rawPeekU8(bytecode, effectiveOffset); - effectiveOffset++; - break; - case BytecodeBitEncoding.DATA_SEG_LENGTH_U16: - dataLength = BinaryStreamParser.rawPeekU16(bytecode, effectiveOffset); - effectiveOffset += 2; - break; - case BytecodeBitEncoding.DATA_SEG_LENGTH_I32: - dataLength = BinaryStreamParser.rawPeekI32(bytecode, effectiveOffset); - effectiveOffset += 4; - break; - default: - throw CompilerDirectives.shouldNotReachHere(); - } - if (dataMode == SegmentMode.ACTIVE) { - final long value; - switch (encoding & BytecodeBitEncoding.DATA_SEG_VALUE_MASK) { - case BytecodeBitEncoding.DATA_SEG_VALUE_UNDEFINED: - value = -1; - break; - case BytecodeBitEncoding.DATA_SEG_VALUE_U8: - value = BinaryStreamParser.rawPeekU8(bytecode, effectiveOffset); - effectiveOffset++; - break; - case BytecodeBitEncoding.DATA_SEG_VALUE_U16: - value = BinaryStreamParser.rawPeekU16(bytecode, effectiveOffset); - effectiveOffset += 2; - break; - case BytecodeBitEncoding.DATA_SEG_VALUE_U32: - value = BinaryStreamParser.rawPeekU32(bytecode, effectiveOffset); - effectiveOffset += 4; - break; - case BytecodeBitEncoding.DATA_SEG_VALUE_I64: - value = BinaryStreamParser.rawPeekI64(bytecode, effectiveOffset); - effectiveOffset += 8; - break; - default: - throw CompilerDirectives.shouldNotReachHere(); - } - final byte[] dataOffsetBytecode; - final long dataOffsetAddress; - if ((encoding & BytecodeBitEncoding.DATA_SEG_BYTECODE_OR_OFFSET_MASK) == BytecodeBitEncoding.DATA_SEG_BYTECODE && - ((encoding & BytecodeBitEncoding.DATA_SEG_VALUE_MASK) != BytecodeBitEncoding.DATA_SEG_VALUE_UNDEFINED)) { - int dataOffsetBytecodeLength = (int) value; - dataOffsetBytecode = Arrays.copyOfRange(bytecode, effectiveOffset, effectiveOffset + dataOffsetBytecodeLength); - effectiveOffset += dataOffsetBytecodeLength; - dataOffsetAddress = -1; - } else { - dataOffsetBytecode = null; - dataOffsetAddress = value; - } - - final int memoryIndex; - if ((encoding & BytecodeBitEncoding.DATA_SEG_HAS_MEMORY_INDEX_ZERO) != 0) { - memoryIndex = 0; - } else { - final int memoryIndexEncoding = bytecode[effectiveOffset]; - effectiveOffset++; - switch (memoryIndexEncoding & BytecodeBitEncoding.DATA_SEG_MEMORY_INDEX_MASK) { - case BytecodeBitEncoding.DATA_SEG_MEMORY_INDEX_U6: - memoryIndex = memoryIndexEncoding & BytecodeBitEncoding.DATA_SEG_MEMORY_INDEX_VALUE; - break; - case BytecodeBitEncoding.DATA_SEG_MEMORY_INDEX_U8: - memoryIndex = BinaryStreamParser.rawPeekU8(bytecode, effectiveOffset); - effectiveOffset++; - break; - case BytecodeBitEncoding.DATA_SEG_MEMORY_INDEX_U16: - memoryIndex = BinaryStreamParser.rawPeekU16(bytecode, effectiveOffset); - effectiveOffset += 2; - break; - case BytecodeBitEncoding.DATA_SEG_MEMORY_INDEX_I32: - memoryIndex = BinaryStreamParser.rawPeekI32(bytecode, effectiveOffset); - effectiveOffset += 4; - break; - default: - throw CompilerDirectives.shouldNotReachHere(); - } - } - - final int dataBytecodeOffset = effectiveOffset; - linkActions.add((context, store, instance, imports) -> { - store.linker().resolveDataSegment(store, instance, dataIndex, memoryIndex, dataOffsetAddress, dataOffsetBytecode, dataLength, - dataBytecodeOffset, instance.droppedDataInstanceOffset()); - }); - } else { - final int dataBytecodeOffset = effectiveOffset; - linkActions.add((context, store, instance, imports) -> { - store.linker().resolvePassiveDataSegment(store, instance, dataIndex, dataBytecodeOffset); - }); - } - } - for (int i = 0; i < module.elemInstanceCount(); i++) { final int elemIndex = i; final int elemOffset = module.elemInstanceOffset(elemIndex); @@ -450,6 +348,108 @@ static List recreateLinkActions(WasmModule module) { } } + for (int i = 0; i < module.dataInstanceCount(); i++) { + final int dataIndex = i; + final int dataOffset = module.dataInstanceOffset(dataIndex); + final int encoding = bytecode[dataOffset]; + int effectiveOffset = dataOffset + 1; + + final int dataMode = encoding & BytecodeBitEncoding.DATA_SEG_MODE_VALUE; + final int dataLength; + switch (encoding & BytecodeBitEncoding.DATA_SEG_LENGTH_MASK) { + case BytecodeBitEncoding.DATA_SEG_LENGTH_U8: + dataLength = BinaryStreamParser.rawPeekU8(bytecode, effectiveOffset); + effectiveOffset++; + break; + case BytecodeBitEncoding.DATA_SEG_LENGTH_U16: + dataLength = BinaryStreamParser.rawPeekU16(bytecode, effectiveOffset); + effectiveOffset += 2; + break; + case BytecodeBitEncoding.DATA_SEG_LENGTH_I32: + dataLength = BinaryStreamParser.rawPeekI32(bytecode, effectiveOffset); + effectiveOffset += 4; + break; + default: + throw CompilerDirectives.shouldNotReachHere(); + } + if (dataMode == SegmentMode.ACTIVE) { + final long value; + switch (encoding & BytecodeBitEncoding.DATA_SEG_VALUE_MASK) { + case BytecodeBitEncoding.DATA_SEG_VALUE_UNDEFINED: + value = -1; + break; + case BytecodeBitEncoding.DATA_SEG_VALUE_U8: + value = BinaryStreamParser.rawPeekU8(bytecode, effectiveOffset); + effectiveOffset++; + break; + case BytecodeBitEncoding.DATA_SEG_VALUE_U16: + value = BinaryStreamParser.rawPeekU16(bytecode, effectiveOffset); + effectiveOffset += 2; + break; + case BytecodeBitEncoding.DATA_SEG_VALUE_U32: + value = BinaryStreamParser.rawPeekU32(bytecode, effectiveOffset); + effectiveOffset += 4; + break; + case BytecodeBitEncoding.DATA_SEG_VALUE_I64: + value = BinaryStreamParser.rawPeekI64(bytecode, effectiveOffset); + effectiveOffset += 8; + break; + default: + throw CompilerDirectives.shouldNotReachHere(); + } + final byte[] dataOffsetBytecode; + final long dataOffsetAddress; + if ((encoding & BytecodeBitEncoding.DATA_SEG_BYTECODE_OR_OFFSET_MASK) == BytecodeBitEncoding.DATA_SEG_BYTECODE && + ((encoding & BytecodeBitEncoding.DATA_SEG_VALUE_MASK) != BytecodeBitEncoding.DATA_SEG_VALUE_UNDEFINED)) { + int dataOffsetBytecodeLength = (int) value; + dataOffsetBytecode = Arrays.copyOfRange(bytecode, effectiveOffset, effectiveOffset + dataOffsetBytecodeLength); + effectiveOffset += dataOffsetBytecodeLength; + dataOffsetAddress = -1; + } else { + dataOffsetBytecode = null; + dataOffsetAddress = value; + } + + final int memoryIndex; + if ((encoding & BytecodeBitEncoding.DATA_SEG_HAS_MEMORY_INDEX_ZERO) != 0) { + memoryIndex = 0; + } else { + final int memoryIndexEncoding = bytecode[effectiveOffset]; + effectiveOffset++; + switch (memoryIndexEncoding & BytecodeBitEncoding.DATA_SEG_MEMORY_INDEX_MASK) { + case BytecodeBitEncoding.DATA_SEG_MEMORY_INDEX_U6: + memoryIndex = memoryIndexEncoding & BytecodeBitEncoding.DATA_SEG_MEMORY_INDEX_VALUE; + break; + case BytecodeBitEncoding.DATA_SEG_MEMORY_INDEX_U8: + memoryIndex = BinaryStreamParser.rawPeekU8(bytecode, effectiveOffset); + effectiveOffset++; + break; + case BytecodeBitEncoding.DATA_SEG_MEMORY_INDEX_U16: + memoryIndex = BinaryStreamParser.rawPeekU16(bytecode, effectiveOffset); + effectiveOffset += 2; + break; + case BytecodeBitEncoding.DATA_SEG_MEMORY_INDEX_I32: + memoryIndex = BinaryStreamParser.rawPeekI32(bytecode, effectiveOffset); + effectiveOffset += 4; + break; + default: + throw CompilerDirectives.shouldNotReachHere(); + } + } + + final int dataBytecodeOffset = effectiveOffset; + linkActions.add((context, store, instance, imports) -> { + store.linker().resolveDataSegment(store, instance, dataIndex, memoryIndex, dataOffsetAddress, dataOffsetBytecode, dataLength, + dataBytecodeOffset, instance.droppedDataInstanceOffset()); + }); + } else { + final int dataBytecodeOffset = effectiveOffset; + linkActions.add((context, store, instance, imports) -> { + store.linker().resolvePassiveDataSegment(store, instance, dataIndex, dataBytecodeOffset); + }); + } + } + return linkActions; } From 538847ea9736f737e3b3f6cf2186d484b40d01a7 Mon Sep 17 00:00:00 2001 From: Jirka Marsik Date: Sat, 25 Oct 2025 01:12:19 +0200 Subject: [PATCH 38/40] Merge bytecode handlers for call_indirect and call_ref --- .../src/org/graalvm/wasm/BinaryParser.java | 2 + .../org/graalvm/wasm/constants/Bytecode.java | 8 +- .../graalvm/wasm/constants/StackEffects.java | 6 +- .../graalvm/wasm/nodes/WasmFunctionNode.java | 161 ++++++++---------- .../wasm/parser/bytecode/BytecodeParser.java | 24 +-- .../parser/bytecode/RuntimeBytecodeGen.java | 2 - 6 files changed, 89 insertions(+), 114 deletions(-) diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java index 08834a1b1303..7790fc7ee553 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/BinaryParser.java @@ -987,6 +987,7 @@ private CodeEntry readFunction(int functionIndex, int[] locals, int sourceCodeEn final int elementType = module.tableElementType(index); state.popChecked(I32_TYPE); state.push(elementType); + state.addMiscFlag(); state.addInstruction(Bytecode.TABLE_GET, index); break; } @@ -996,6 +997,7 @@ private CodeEntry readFunction(int functionIndex, int[] locals, int sourceCodeEn final int elementType = module.tableElementType(index); state.popChecked(elementType); state.popChecked(I32_TYPE); + state.addMiscFlag(); state.addInstruction(Bytecode.TABLE_SET, index); break; } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/Bytecode.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/Bytecode.java index b0c1fbb67e60..ae0d5d040127 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/Bytecode.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/Bytecode.java @@ -333,8 +333,8 @@ public class Bytecode { public static final int REF_IS_NULL = 0xF7; public static final int REF_FUNC = 0xF8; - public static final int TABLE_GET = 0xF9; - public static final int TABLE_SET = 0xFA; + public static final int CALL_REF_U8 = 0xF9; + public static final int CALL_REF_I32 = 0xFA; public static final int MISC = 0xFB; @@ -376,8 +376,8 @@ public class Bytecode { public static final int THROW_REF = 0x1C; // Typed function references opcodes - public static final int CALL_REF_U8 = 0x1D; - public static final int CALL_REF_I32 = 0x1E; + public static final int TABLE_GET = 0x1D; + public static final int TABLE_SET = 0x1E; public static final int REF_AS_NON_NULL = 0x1F; public static final int BR_ON_NULL_U8 = 0x20; public static final int BR_ON_NULL_I32 = 0x21; diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/StackEffects.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/StackEffects.java index e14c571832c3..710f3bf543c2 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/StackEffects.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/constants/StackEffects.java @@ -91,10 +91,8 @@ public final class StackEffects { // followed by throw miscOpStackEffects[Bytecode.THROW_REF] = UNREACHABLE; // unused, because stack effect is // followed by throw - miscOpStackEffects[Bytecode.CALL_REF_U8] = UNREACHABLE; // unused, because stack effect is - // variable - miscOpStackEffects[Bytecode.CALL_REF_I32] = UNREACHABLE; // unused, because stack effect is - // variable + miscOpStackEffects[Bytecode.TABLE_GET] = NO_EFFECT; + miscOpStackEffects[Bytecode.TABLE_SET] = POP_2; miscOpStackEffects[Bytecode.REF_AS_NON_NULL] = NO_EFFECT; miscOpStackEffects[Bytecode.BR_ON_NULL_U8] = UNREACHABLE; // unused, because stack effect is // dynamic diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java index a527aeab6df0..cc61ce761b15 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java @@ -535,9 +535,9 @@ public Object executeBodyFromOffset(WasmInstance instance, VirtualFrame frame, i break; } case Bytecode.CALL_INDIRECT_U8: - case Bytecode.CALL_INDIRECT_I32: { - // Extract the function object. - stackPointer--; + case Bytecode.CALL_INDIRECT_I32: + case Bytecode.CALL_REF_U8: + case Bytecode.CALL_REF_I32: { final SymbolTable symtab = module.symbolTable(); final int callNodeIndex; @@ -548,53 +548,83 @@ public Object executeBodyFromOffset(WasmInstance instance, VirtualFrame frame, i expectedFunctionTypeIndex = rawPeekU8(bytecode, offset + 1); tableIndex = rawPeekU8(bytecode, offset + 2); offset += 3; - } else { + } else if (opcode == Bytecode.CALL_INDIRECT_I32) { callNodeIndex = rawPeekI32(bytecode, offset); expectedFunctionTypeIndex = rawPeekI32(bytecode, offset + 4); tableIndex = rawPeekI32(bytecode, offset + 8); offset += 12; + } else if (opcode == Bytecode.CALL_REF_U8) { + callNodeIndex = rawPeekU8(bytecode, offset); + expectedFunctionTypeIndex = rawPeekU8(bytecode, offset + 1); + tableIndex = -1; + offset += 2; + } else { + assert opcode == Bytecode.CALL_REF_I32; + callNodeIndex = rawPeekI32(bytecode, offset); + expectedFunctionTypeIndex = rawPeekI32(bytecode, offset + 4); + tableIndex = -1; + offset += 8; } - final WasmTable table = instance.store().tables().table(instance.tableAddress(tableIndex)); - final Object[] elements = table.elements(); - final int elementIndex = popInt(frame, stackPointer); - if (elementIndex < 0 || elementIndex >= elements.length) { - enterErrorBranch(); - throw WasmException.format(Failure.UNDEFINED_ELEMENT, this, "Element index '%d' out of table bounds.", elementIndex); - } - // Currently, table elements may only be functions. - // We can add a check here when this changes in the future. - final Object element = elements[elementIndex]; - if (element == WasmConstant.NULL) { - enterErrorBranch(); - throw WasmException.format(Failure.UNINITIALIZED_ELEMENT, this, "Table element at index %d is uninitialized.", elementIndex); + + // Extract the function object. + final Object functionCandidate; + final int elementIndex; + if (opcode == Bytecode.CALL_INDIRECT_U8 || opcode == Bytecode.CALL_INDIRECT_I32) { + final WasmTable table = instance.store().tables().table(instance.tableAddress(tableIndex)); + final Object[] elements = table.elements(); + elementIndex = popInt(frame, --stackPointer); + if (elementIndex < 0 || elementIndex >= elements.length) { + enterErrorBranch(); + throw WasmException.format(Failure.UNDEFINED_ELEMENT, this, "Element index '%d' out of table bounds.", elementIndex); + } + // Currently, table elements may only be functions. + // We can add a check here when this changes in the future. + functionCandidate = elements[elementIndex]; + } else { + assert opcode == Bytecode.CALL_REF_U8 || opcode == Bytecode.CALL_REF_I32; + functionCandidate = popReference(frame, --stackPointer); + elementIndex = -1; } final WasmFunctionInstance functionInstance; final WasmFunction function; final CallTarget target; final WasmContext functionInstanceContext; - if (element instanceof WasmFunctionInstance) { - functionInstance = (WasmFunctionInstance) element; + if (functionCandidate == WasmConstant.NULL) { + enterErrorBranch(); + if (opcode == Bytecode.CALL_INDIRECT_U8 || opcode == Bytecode.CALL_INDIRECT_I32) { + throw WasmException.format(Failure.UNINITIALIZED_ELEMENT, this, "Table element at index %d is uninitialized.", elementIndex); + } else { + assert opcode == Bytecode.CALL_REF_U8 || opcode == Bytecode.CALL_REF_I32; + throw WasmException.format(Failure.NULL_FUNCTION_REFERENCE, this, "Function reference is null"); + } + } else if (functionCandidate instanceof WasmFunctionInstance) { + functionInstance = (WasmFunctionInstance) functionCandidate; function = functionInstance.function(); target = functionInstance.target(); functionInstanceContext = functionInstance.context(); } else { enterErrorBranch(); - throw WasmException.format(Failure.UNSPECIFIED_TRAP, this, "Unknown table element type: %s", element); + if (opcode == Bytecode.CALL_INDIRECT_U8 || opcode == Bytecode.CALL_INDIRECT_I32) { + throw WasmException.format(Failure.UNSPECIFIED_TRAP, this, "Unknown table element type: %s", functionCandidate); + } else { + assert opcode == Bytecode.CALL_REF_U8 || opcode == Bytecode.CALL_REF_I32; + throw WasmException.format(Failure.UNSPECIFIED_TRAP, this, "Unknown function object: %s", functionCandidate); + } } - int expectedTypeEquivalenceClass = symtab.equivalenceClass(expectedFunctionTypeIndex); - // Target function instance must be from the same context. assert functionInstanceContext == WasmContext.get(this); - // Validate that the target function type matches the expected type of the - // indirect call. We first try if the types are equivalent using the - // equivalence classes. If they are not equivalent, we run the full subtype - // matching procedure. - if (expectedTypeEquivalenceClass != function.typeEquivalenceClass() && - !symtab.closedTypeAt(expectedFunctionTypeIndex).isSupertypeOf(function.closedType())) { - enterErrorBranch(); - failFunctionTypeCheck(function, expectedFunctionTypeIndex); + if (opcode == Bytecode.CALL_INDIRECT_U8 || opcode == Bytecode.CALL_INDIRECT_I32) { + // Validate that the target function type matches the expected type of + // the indirect call. We first try if the types are equivalent using the + // equivalence classes. If they are not equivalent, we run the full + // subtype matching procedure. + if (symtab.equivalenceClass(expectedFunctionTypeIndex) != function.typeEquivalenceClass() && + !symtab.closedTypeAt(expectedFunctionTypeIndex).isSupertypeOf(function.closedType())) { + enterErrorBranch(); + failFunctionTypeCheck(function, expectedFunctionTypeIndex); + } } // Invoke the resolved function. @@ -1509,19 +1539,6 @@ public Object executeBodyFromOffset(WasmInstance instance, VirtualFrame frame, i offset += 4; break; } - case Bytecode.TABLE_GET: { - final int tableIndex = rawPeekI32(bytecode, offset); - table_get(instance, frame, stackPointer, tableIndex); - offset += 4; - break; - } - case Bytecode.TABLE_SET: { - final int tableIndex = rawPeekI32(bytecode, offset); - table_set(instance, frame, stackPointer, tableIndex); - stackPointer -= 2; - offset += 4; - break; - } case Bytecode.MISC: { final int miscOpcode = rawPeekU8(bytecode, offset); offset++; @@ -1548,51 +1565,6 @@ public Object executeBodyFromOffset(WasmInstance instance, VirtualFrame frame, i assert exception instanceof WasmRuntimeException : "Only wasm exceptions can be thrown by throw_ref"; throw (WasmRuntimeException) exception; } - case Bytecode.CALL_REF_U8: - case Bytecode.CALL_REF_I32: { - final int callNodeIndex; - final int expectedFunctionTypeIndex; - if (miscOpcode == Bytecode.CALL_REF_U8) { - callNodeIndex = rawPeekU8(bytecode, offset); - expectedFunctionTypeIndex = rawPeekU8(bytecode, offset + 1); - offset += 2; - } else { - callNodeIndex = rawPeekI32(bytecode, offset); - expectedFunctionTypeIndex = rawPeekI32(bytecode, offset + 4); - offset += 8; - } - - // Extract the function object. - final WasmFunctionInstance functionInstance; - final CallTarget target; - final WasmContext functionInstanceContext; - final Object functionOrNull = popReference(frame, --stackPointer); - if (functionOrNull == WasmConstant.NULL) { - enterErrorBranch(); - throw WasmException.format(Failure.NULL_FUNCTION_REFERENCE, this, "Function reference is null"); - } else if (functionOrNull instanceof WasmFunctionInstance) { - functionInstance = (WasmFunctionInstance) functionOrNull; - target = functionInstance.target(); - functionInstanceContext = functionInstance.context(); - } else { - enterErrorBranch(); - throw WasmException.format(Failure.UNSPECIFIED_TRAP, this, "Unknown function object: %s", functionOrNull); - } - - // Target function instance must be from the same context. - assert functionInstanceContext == WasmContext.get(this); - - // Invoke the resolved function. - int paramCount = module.symbolTable().functionTypeParamCount(expectedFunctionTypeIndex); - Object[] args = createArgumentsForCall(frame, expectedFunctionTypeIndex, paramCount, stackPointer); - stackPointer -= paramCount; - WasmArguments.setModuleInstance(args, functionInstance.moduleInstance()); - - final Object result = executeIndirectCallNode(callNodeIndex, target, args); - stackPointer = pushIndirectCallResult(frame, stackPointer, expectedFunctionTypeIndex, result, WasmLanguage.get(this)); - CompilerAsserts.partialEvaluationConstant(stackPointer); - break; - } case Bytecode.BR_ON_NULL_U8: { Object reference = popReference(frame, --stackPointer); if (profileCondition(bytecode, offset + 1, reference == WasmConstant.NULL)) { @@ -2271,6 +2243,19 @@ private int executeMisc(WasmInstance instance, VirtualFrame frame, int startingO pushLong(frame, stackPointer - 1, previousSize); break; } + case Bytecode.TABLE_GET: { + final int tableIndex = rawPeekI32(bytecode, offset); + table_get(instance, frame, stackPointer, tableIndex); + offset += 4; + break; + } + case Bytecode.TABLE_SET: { + final int tableIndex = rawPeekI32(bytecode, offset); + table_set(instance, frame, stackPointer, tableIndex); + stackPointer -= 2; + offset += 4; + break; + } case Bytecode.REF_AS_NON_NULL: { Object reference = popReference(frame, stackPointer - 1); if (reference == WasmConstant.NULL) { diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/BytecodeParser.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/BytecodeParser.java index 978a5250af39..3845cc7e4d7a 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/BytecodeParser.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/BytecodeParser.java @@ -483,12 +483,14 @@ private static List readCallNodes(byte[] bytecode, int startOffset, in offset += 8; break; } - case Bytecode.CALL_INDIRECT_U8: { + case Bytecode.CALL_INDIRECT_U8: + case Bytecode.CALL_REF_U8: { callNodes.add(new CallNode(originalOffset)); offset += 3; break; } - case Bytecode.CALL_INDIRECT_I32: { + case Bytecode.CALL_INDIRECT_I32: + case Bytecode.CALL_REF_I32: { callNodes.add(new CallNode(originalOffset)); offset += 12; break; @@ -722,9 +724,7 @@ private static List readCallNodes(byte[] bytecode, int startOffset, in case Bytecode.I64_STORE_32_I32: case Bytecode.I32_CONST_I32: case Bytecode.F32_CONST: - case Bytecode.REF_FUNC: - case Bytecode.TABLE_GET: - case Bytecode.TABLE_SET: { + case Bytecode.REF_FUNC: { offset += 4; break; } @@ -797,16 +797,6 @@ private static List readCallNodes(byte[] bytecode, int startOffset, in int miscOpcode = rawPeekU8(bytecode, offset); offset++; switch (miscOpcode) { - case Bytecode.CALL_REF_U8: { - callNodes.add(new CallNode(originalOffset)); - offset += 3; - break; - } - case Bytecode.CALL_REF_I32: { - callNodes.add(new CallNode(originalOffset)); - offset += 12; - break; - } case Bytecode.I32_TRUNC_SAT_F32_S: case Bytecode.I32_TRUNC_SAT_F32_U: case Bytecode.I32_TRUNC_SAT_F64_S: @@ -832,7 +822,9 @@ private static List readCallNodes(byte[] bytecode, int startOffset, in case Bytecode.TABLE_GROW: case Bytecode.TABLE_SIZE: case Bytecode.TABLE_FILL: - case Bytecode.THROW: { + case Bytecode.THROW: + case Bytecode.TABLE_GET: + case Bytecode.TABLE_SET: { offset += 4; break; } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/RuntimeBytecodeGen.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/RuntimeBytecodeGen.java index b5212dc19a57..af5b59225302 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/RuntimeBytecodeGen.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/parser/bytecode/RuntimeBytecodeGen.java @@ -574,12 +574,10 @@ public void addIndirectCall(int nodeIndex, int typeIndex, int tableIndex) { */ public void addRefCall(int nodeIndex, int typeIndex) { if (fitsIntoUnsignedByte(nodeIndex) && fitsIntoUnsignedByte(typeIndex)) { - add1(Bytecode.MISC); add1(Bytecode.CALL_REF_U8); add1(nodeIndex); add1(typeIndex); } else { - add1(Bytecode.MISC); add1(Bytecode.CALL_REF_I32); add4(nodeIndex); add4(typeIndex); From 231c7a99ffbcdfad21ccc4860d5473f83d916eca Mon Sep 17 00:00:00 2001 From: Jirka Marsik Date: Mon, 27 Oct 2025 09:07:23 +0100 Subject: [PATCH 39/40] Profile out subtyping check in call_indirect --- .../src/org/graalvm/wasm/WasmCodeEntry.java | 5 +++++ .../org/graalvm/wasm/nodes/WasmFunctionNode.java | 15 +++++++++------ 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmCodeEntry.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmCodeEntry.java index 645352e95e37..4e8636195de9 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmCodeEntry.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmCodeEntry.java @@ -51,6 +51,7 @@ public final class WasmCodeEntry { @CompilationFinal(dimensions = 1) private final int[] resultTypes; private final BranchProfile errorBranch = BranchProfile.create(); private final BranchProfile exceptionBranch = BranchProfile.create(); + private final BranchProfile subtypingBranch = BranchProfile.create(); private final int numLocals; private final int resultCount; private final boolean usesMemoryZero; @@ -101,6 +102,10 @@ public void exceptionBranch() { exceptionBranch.enter(); } + public void subtypingBranch() { + subtypingBranch.enter(); + } + public boolean usesMemoryZero() { return usesMemoryZero; } diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java index cc61ce761b15..3c5b79886220 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java @@ -618,12 +618,15 @@ public Object executeBodyFromOffset(WasmInstance instance, VirtualFrame frame, i if (opcode == Bytecode.CALL_INDIRECT_U8 || opcode == Bytecode.CALL_INDIRECT_I32) { // Validate that the target function type matches the expected type of // the indirect call. We first try if the types are equivalent using the - // equivalence classes. If they are not equivalent, we run the full - // subtype matching procedure. - if (symtab.equivalenceClass(expectedFunctionTypeIndex) != function.typeEquivalenceClass() && - !symtab.closedTypeAt(expectedFunctionTypeIndex).isSupertypeOf(function.closedType())) { - enterErrorBranch(); - failFunctionTypeCheck(function, expectedFunctionTypeIndex); + // equivalence classes... + if (symtab.equivalenceClass(expectedFunctionTypeIndex) != function.typeEquivalenceClass()) { + codeEntry.subtypingBranch(); + // If they are not equivalent, we run the full subtype matching + // procedure. + if (!symtab.closedTypeAt(expectedFunctionTypeIndex).isSupertypeOf(function.closedType())) { + enterErrorBranch(); + failFunctionTypeCheck(function, expectedFunctionTypeIndex); + } } } From a1962e6251227dbe1572e14a3990e1583871c746 Mon Sep 17 00:00:00 2001 From: Jirka Marsik Date: Mon, 27 Oct 2025 09:21:13 +0100 Subject: [PATCH 40/40] Outline call_indirect/call_ref not-a-function error cases --- .../graalvm/wasm/nodes/WasmFunctionNode.java | 52 ++++++++++--------- 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java index 3c5b79886220..746babb31354 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java @@ -585,33 +585,15 @@ public Object executeBodyFromOffset(WasmInstance instance, VirtualFrame frame, i functionCandidate = popReference(frame, --stackPointer); elementIndex = -1; } - final WasmFunctionInstance functionInstance; - final WasmFunction function; - final CallTarget target; - final WasmContext functionInstanceContext; - if (functionCandidate == WasmConstant.NULL) { - enterErrorBranch(); - if (opcode == Bytecode.CALL_INDIRECT_U8 || opcode == Bytecode.CALL_INDIRECT_I32) { - throw WasmException.format(Failure.UNINITIALIZED_ELEMENT, this, "Table element at index %d is uninitialized.", elementIndex); - } else { - assert opcode == Bytecode.CALL_REF_U8 || opcode == Bytecode.CALL_REF_I32; - throw WasmException.format(Failure.NULL_FUNCTION_REFERENCE, this, "Function reference is null"); - } - } else if (functionCandidate instanceof WasmFunctionInstance) { - functionInstance = (WasmFunctionInstance) functionCandidate; - function = functionInstance.function(); - target = functionInstance.target(); - functionInstanceContext = functionInstance.context(); - } else { - enterErrorBranch(); - if (opcode == Bytecode.CALL_INDIRECT_U8 || opcode == Bytecode.CALL_INDIRECT_I32) { - throw WasmException.format(Failure.UNSPECIFIED_TRAP, this, "Unknown table element type: %s", functionCandidate); - } else { - assert opcode == Bytecode.CALL_REF_U8 || opcode == Bytecode.CALL_REF_I32; - throw WasmException.format(Failure.UNSPECIFIED_TRAP, this, "Unknown function object: %s", functionCandidate); - } + + if (!(functionCandidate instanceof WasmFunctionInstance functionInstance)) { + throw callIndirectNotAFunctionError(opcode, functionCandidate, elementIndex); } + final WasmFunction function = functionInstance.function(); + final CallTarget target = functionInstance.target(); + final WasmContext functionInstanceContext = functionInstance.context(); + // Target function instance must be from the same context. assert functionInstanceContext == WasmContext.get(this); @@ -1778,6 +1760,26 @@ private void failFunctionTypeCheck(WasmFunction function, int expectedFunctionTy function.typeIndex(), function.name(), expectedFunctionTypeIndex, module.name()); } + @HostCompilerDirectives.InliningCutoff + private WasmException callIndirectNotAFunctionError(int opcode, Object functionCandidate, int elementIndex) { + enterErrorBranch(); + if (functionCandidate == WasmConstant.NULL) { + if (opcode == Bytecode.CALL_INDIRECT_U8 || opcode == Bytecode.CALL_INDIRECT_I32) { + throw WasmException.format(Failure.UNINITIALIZED_ELEMENT, this, "Table element at index %d is uninitialized.", elementIndex); + } else { + assert opcode == Bytecode.CALL_REF_U8 || opcode == Bytecode.CALL_REF_I32; + throw WasmException.format(Failure.NULL_FUNCTION_REFERENCE, this, "Function reference is null"); + } + } else { + if (opcode == Bytecode.CALL_INDIRECT_U8 || opcode == Bytecode.CALL_INDIRECT_I32) { + throw WasmException.format(Failure.UNSPECIFIED_TRAP, this, "Unknown table element type: %s", functionCandidate); + } else { + assert opcode == Bytecode.CALL_REF_U8 || opcode == Bytecode.CALL_REF_I32; + throw WasmException.format(Failure.UNSPECIFIED_TRAP, this, "Unknown function object: %s", functionCandidate); + } + } + } + private void check(int v, int limit) { // This is a temporary hack to hoist values out of the loop. if (v >= limit) {