From 393a805741c2528a9eb46c457d2e8f939a7fb2b3 Mon Sep 17 00:00:00 2001 From: Robin Voetter Date: Fri, 29 Mar 2024 13:54:28 +0100 Subject: [PATCH 1/5] spirv: deduplicate prototype --- src/link/SpirV.zig | 2 + src/link/SpirV/BinaryModule.zig | 2 + src/link/SpirV/deduplicate.zig | 400 ++++++++++++++++++++++++++++++++ 3 files changed, 404 insertions(+) create mode 100644 src/link/SpirV/deduplicate.zig diff --git a/src/link/SpirV.zig b/src/link/SpirV.zig index dc25ac510579..728db2d8484a 100644 --- a/src/link/SpirV.zig +++ b/src/link/SpirV.zig @@ -261,6 +261,7 @@ fn linkModule(self: *SpirV, a: Allocator, module: []Word) ![]Word { const lower_invocation_globals = @import("SpirV/lower_invocation_globals.zig"); const prune_unused = @import("SpirV/prune_unused.zig"); + const dedup = @import("SpirV/deduplicate.zig"); var parser = try BinaryModule.Parser.init(a); defer parser.deinit(); @@ -268,6 +269,7 @@ fn linkModule(self: *SpirV, a: Allocator, module: []Word) ![]Word { try lower_invocation_globals.run(&parser, &binary); try prune_unused.run(&parser, &binary); + try dedup.run(&parser, &binary); return binary.finalize(a); } diff --git a/src/link/SpirV/BinaryModule.zig b/src/link/SpirV/BinaryModule.zig index 0c9c32c98e04..e150890315bf 100644 --- a/src/link/SpirV/BinaryModule.zig +++ b/src/link/SpirV/BinaryModule.zig @@ -94,6 +94,8 @@ pub const ParseError = error{ DuplicateId, /// Some ID did not resolve. InvalidId, + /// This opcode or instruction is not supported yet. + UnsupportedOperation, /// Parser ran out of memory. OutOfMemory, }; diff --git a/src/link/SpirV/deduplicate.zig b/src/link/SpirV/deduplicate.zig new file mode 100644 index 000000000000..be067e03e4e1 --- /dev/null +++ b/src/link/SpirV/deduplicate.zig @@ -0,0 +1,400 @@ +const std = @import("std"); +const Allocator = std.mem.Allocator; +const log = std.log.scoped(.spirv_link); +const assert = std.debug.assert; + +const BinaryModule = @import("BinaryModule.zig"); +const Section = @import("../../codegen/spirv/Section.zig"); +const spec = @import("../../codegen/spirv/spec.zig"); +const Opcode = spec.Opcode; +const ResultId = spec.IdResult; +const Word = spec.Word; + +fn canDeduplicate(opcode: Opcode) bool { + return switch (opcode) { + .OpTypeForwardPointer => false, // Don't need to handle these + .OpGroupDecorate, .OpGroupMemberDecorate => { + // These are deprecated, so don't bother supporting them for now. + return false; + }, + .OpName, .OpMemberName => true, // Debug decoration-style instructions + else => switch (opcode.class()) { + .TypeDeclaration, + .ConstantCreation, + .Annotation, + => true, + else => false, + }, + }; +} + +const ModuleInfo = struct { + /// This models a type, decoration or constant instruction + /// and its dependencies. + const Entity = struct { + /// The type that this entity represents. This is just + /// the instruction opcode. + kind: Opcode, + /// Offset of first child result-id, stored in entity_children. + /// These are the shallow entities appearing directly in the + /// type's instruction. + first_child: u32, + /// Offset to the first word of extra-data: Data in the instruction + /// that must be considered for uniqueness, but doesn't include + /// any IDs. + first_extra_data: u32, + }; + + /// Maps result-id to Entity's + entities: std.AutoArrayHashMapUnmanaged(ResultId, Entity), + /// The list of children per instruction. + entity_children: []const ResultId, + /// The list of extra data per instruction. + /// TODO: This is a bit awkward, maybe we need to store it some + /// other way? 
+ extra_data: []const u32, + + pub fn parse( + arena: Allocator, + parser: *BinaryModule.Parser, + binary: BinaryModule, + ) !ModuleInfo { + var entities = std.AutoArrayHashMap(ResultId, Entity).init(arena); + var entity_children = std.ArrayList(ResultId).init(arena); + var extra_data = std.ArrayList(u32).init(arena); + var id_offsets = std.ArrayList(u16).init(arena); + + var it = binary.iterateInstructions(); + while (it.next()) |inst| { + if (inst.opcode == .OpFunction) break; // No more declarations are possible + if (!canDeduplicate(inst.opcode)) continue; + + id_offsets.items.len = 0; + try parser.parseInstructionResultIds(binary, inst, &id_offsets); + + const result_id_index: u32 = switch (inst.opcode.class()) { + .TypeDeclaration, .Annotation, .Debug => 0, + .ConstantCreation => 1, + else => unreachable, + }; + + const result_id: ResultId = @enumFromInt(inst.operands[id_offsets.items[result_id_index]]); + + const first_child: u32 = @intCast(entity_children.items.len); + const first_extra_data: u32 = @intCast(extra_data.items.len); + + try entity_children.ensureUnusedCapacity(id_offsets.items.len - 1); + try extra_data.ensureUnusedCapacity(inst.operands.len - id_offsets.items.len); + + var id_i: usize = 0; + for (inst.operands, 0..) |operand, i| { + assert(id_i == id_offsets.items.len or id_offsets.items[id_i] >= i); + if (id_i != id_offsets.items.len and id_offsets.items[id_i] == i) { + // Skip .IdResult / .IdResultType. + if (id_i != result_id_index) { + entity_children.appendAssumeCapacity(@enumFromInt(operand)); + } + id_i += 1; + } else { + // Non-id operand, add it to extra data. + extra_data.appendAssumeCapacity(operand); + } + } + + switch (inst.opcode.class()) { + .Annotation, .Debug => { + // TODO + }, + .TypeDeclaration, .ConstantCreation => { + const entry = try entities.getOrPut(result_id); + if (entry.found_existing) { + log.err("type or constant {} has duplicate definition", .{result_id}); + return error.DuplicateId; + } + entry.value_ptr.* = .{ + .kind = inst.opcode, + .first_child = first_child, + .first_extra_data = first_extra_data, + }; + }, + else => unreachable, + } + } + + return ModuleInfo{ + .entities = entities.unmanaged, + .entity_children = entity_children.items, + .extra_data = extra_data.items, + }; + } + + /// Fetch a slice of children for the index corresponding to an entity. + fn childrenByIndex(self: ModuleInfo, index: usize) []const ResultId { + const values = self.entities.values(); + const first_child = values[index].first_child; + if (index == values.len - 1) { + return self.entity_children[first_child..]; + } else { + const next_first_child = values[index + 1].first_child; + return self.entity_children[first_child..next_first_child]; + } + } + + /// Fetch the slice of extra-data for the index corresponding to an entity. 
+ fn extraDataByIndex(self: ModuleInfo, index: usize) []const u32 { + const values = self.entities.values(); + const first_extra_data = values[index].first_extra_data; + if (index == values.len - 1) { + return self.extra_data[first_extra_data..]; + } else { + const next_extra_data = values[index + 1].first_extra_data; + return self.extra_data[first_extra_data..next_extra_data]; + } + } +}; + +const EntityContext = struct { + a: Allocator, + ptr_map_a: std.AutoArrayHashMapUnmanaged(ResultId, void) = .{}, + ptr_map_b: std.AutoArrayHashMapUnmanaged(ResultId, void) = .{}, + info: *const ModuleInfo, + + fn init(a: Allocator, info: *const ModuleInfo) EntityContext { + return .{ + .a = a, + .info = info, + }; + } + + fn deinit(self: *EntityContext) void { + self.ptr_map_a.deinit(self.a); + self.ptr_map_b.deinit(self.a); + + self.* = undefined; + } + + fn equalizeMapCapacity(self: *EntityContext) !void { + const cap = @max(self.ptr_map_a.capacity(), self.ptr_map_b.capacity()); + try self.ptr_map_a.ensureTotalCapacity(self.a, cap); + try self.ptr_map_b.ensureTotalCapacity(self.a, cap); + } + + fn hash(self: *EntityContext, id: ResultId) !u64 { + var hasher = std.hash.Wyhash.init(0); + self.ptr_map_a.clearRetainingCapacity(); + try self.hashInner(&hasher, id); + return hasher.final(); + } + + fn hashInner(self: *EntityContext, hasher: *std.hash.Wyhash, id: ResultId) !void { + const index = self.info.entities.getIndex(id).?; + const entity = self.info.entities.values()[index]; + + std.hash.autoHash(hasher, entity.kind); + if (entity.kind == .OpTypePointer) { + // This may be either a pointer that is forward-referenced in the future, + // or a forward reference to a pointer. + const entry = try self.ptr_map_a.getOrPut(self.a, id); + if (entry.found_existing) { + // Pointer already seen. Hash the index instead of recursing into its children. + // TODO: Discriminate this path somehow? + std.hash.autoHash(hasher, entry.index); + return; + } + } + + // Hash extra data + for (self.info.extraDataByIndex(index)) |data| { + std.hash.autoHash(hasher, data); + } + + // Hash children + for (self.info.childrenByIndex(index)) |child| { + try self.hashInner(hasher, child); + } + } + + fn eql(self: *EntityContext, a: ResultId, b: ResultId) !bool { + self.ptr_map_a.clearRetainingCapacity(); + self.ptr_map_b.clearRetainingCapacity(); + + return try self.eqlInner(a, b); + } + + fn eqlInner(self: *EntityContext, id_a: ResultId, id_b: ResultId) !bool { + const index_a = self.info.entities.getIndex(id_a).?; + const index_b = self.info.entities.getIndex(id_b).?; + + const entity_a = self.info.entities.values()[index_a]; + const entity_b = self.info.entities.values()[index_b]; + + if (entity_a.kind != entity_b.kind) return false; + + if (entity_a.kind == .OpTypePointer) { + // May be a forward reference, or should be saved as a potential + // forward reference in the future. Whatever the case, it should + // be the same for both a and b. + const entry_a = try self.ptr_map_a.getOrPut(self.a, id_a); + const entry_b = try self.ptr_map_b.getOrPut(self.a, id_b); + + if (entry_a.found_existing != entry_b.found_existing) return false; + if (entry_a.index != entry_b.index) return false; + + if (entry_a.found_existing) { + // No need to recurse. + return true; + } + } + + // Check if extra data is the same. 
+ if (!std.mem.eql(u32, self.info.extraDataByIndex(index_a), self.info.extraDataByIndex(index_b))) { + return false; + } + + // Recursively check if children are the same + const children_a = self.info.childrenByIndex(index_a); + const children_b = self.info.childrenByIndex(index_b); + if (children_a.len != children_b.len) return false; + + for (children_a, children_b) |child_a, child_b| { + if (!try self.eqlInner(child_a, child_b)) { + return false; + } + } + + return true; + } +}; + +/// This struct is a wrapper around EntityContext that adapts it for +/// use in a hash map. Because EntityContext allocates, it cannot be +/// used. This wrapper simply assumes that the maps have been allocated +/// the max amount of memory they are going to use. +/// This is done by pre-hashing all keys. +const EntityHashContext = struct { + entity_context: *EntityContext, + + pub fn hash(self: EntityHashContext, key: ResultId) u64 { + return self.entity_context.hash(key) catch unreachable; + } + + pub fn eql(self: EntityHashContext, a: ResultId, b: ResultId) bool { + return self.entity_context.eql(a, b) catch unreachable; + } +}; + +pub fn run(parser: *BinaryModule.Parser, binary: *BinaryModule) !void { + var arena = std.heap.ArenaAllocator.init(parser.a); + defer arena.deinit(); + const a = arena.allocator(); + + const info = try ModuleInfo.parse(a, parser, binary.*); + log.info("added {} entities", .{info.entities.count()}); + log.info("children size: {}", .{info.entity_children.len}); + log.info("extra data size: {}", .{info.extra_data.len}); + + // Hash all keys once so that the maps can be allocated the right size. + var ctx = EntityContext.init(a, &info); + for (info.entities.keys()) |id| { + _ = try ctx.hash(id); + } + + // hash only uses ptr_map_a, so allocate ptr_map_b too + try ctx.equalizeMapCapacity(); + + // Figure out which entities can be deduplicated. + var map = std.HashMap(ResultId, void, EntityHashContext, 80).initContext(a, .{ + .entity_context = &ctx, + }); + var replace = std.AutoArrayHashMap(ResultId, ResultId).init(a); + for (info.entities.keys(), info.entities.values()) |id, entity| { + const entry = try map.getOrPut(id); + if (entry.found_existing) { + log.info("deduplicating {} - {s} (prior definition: {})", .{ id, @tagName(entity.kind), entry.key_ptr.* }); + try replace.putNoClobber(id, entry.key_ptr.*); + } + } + + // Now process the module, and replace instructions where needed. + var section = Section{}; + var it = binary.iterateInstructions(); + var id_offsets = std.ArrayList(u16).init(a); + var new_functions_section: ?usize = null; + var new_operands = std.ArrayList(u32).init(a); + var emitted_ptrs = std.AutoHashMap(ResultId, void).init(a); + while (it.next()) |inst| { + // Result-id can only be the first or second operand + const inst_spec = parser.getInstSpec(inst.opcode).?; + const maybe_result_id: ?ResultId = for (0..2) |i| { + if (inst_spec.operands.len > i and inst_spec.operands[i].kind == .IdResult) { + break @enumFromInt(inst.operands[i]); + } + } else null; + + if (maybe_result_id) |result_id| { + if (replace.contains(result_id)) continue; + } + + switch (inst.opcode) { + .OpFunction => if (new_functions_section == null) { + new_functions_section = section.instructions.items.len; + }, + .OpTypeForwardPointer => continue, // We re-emit these where needed + // TODO: These aren't supported yet, strip them out for testing purposes. + .OpName, .OpMemberName => continue, + else => {}, + } + + // Re-emit the instruction, but replace all the IDs. 
+ + id_offsets.items.len = 0; + try parser.parseInstructionResultIds(binary.*, inst, &id_offsets); + + new_operands.items.len = 0; + try new_operands.appendSlice(inst.operands); + for (id_offsets.items) |offset| { + { + const id: ResultId = @enumFromInt(inst.operands[offset]); + if (replace.get(id)) |new_id| { + new_operands.items[offset] = @intFromEnum(new_id); + } + } + + // TODO: Does this logic work? Maybe it will emit an OpTypeForwardPointer to + // something thats not a struct... + // It seems to work correctly on behavior.zig at least + const id: ResultId = @enumFromInt(new_operands.items[offset]); + if (maybe_result_id == null or maybe_result_id.? != id) { + const index = info.entities.getIndex(id) orelse continue; + const entity = info.entities.values()[index]; + if (entity.kind == .OpTypePointer) { + if (!emitted_ptrs.contains(id)) { + // The storage class is in the extra data + // TODO: This is kind of hacky... + const extra_data = info.extraDataByIndex(index); + const storage_class: spec.StorageClass = @enumFromInt(extra_data[0]); + try section.emit(a, .OpTypeForwardPointer, .{ + .pointer_type = id, + .storage_class = storage_class, + }); + try emitted_ptrs.put(id, {}); + } + } + } + } + + if (inst.opcode == .OpTypePointer) { + try emitted_ptrs.put(maybe_result_id.?, {}); + } + + try section.emitRawInstruction(a, inst.opcode, new_operands.items); + } + + for (replace.keys()) |key| { + _ = binary.ext_inst_map.remove(key); + _ = binary.arith_type_width.remove(key); + } + + binary.instructions = try parser.a.dupe(Word, section.toWords()); + binary.sections.functions = new_functions_section orelse binary.instructions.len; +} From b4960394efa71a8246b10b46165292f1797aaf87 Mon Sep 17 00:00:00 2001 From: Robin Voetter Date: Fri, 29 Mar 2024 23:39:30 +0100 Subject: [PATCH 2/5] spirv: avoid copying operands in dedup pass --- src/link/SpirV/deduplicate.zig | 211 ++++++++++++++------------------- 1 file changed, 88 insertions(+), 123 deletions(-) diff --git a/src/link/SpirV/deduplicate.zig b/src/link/SpirV/deduplicate.zig index be067e03e4e1..fafc608e957b 100644 --- a/src/link/SpirV/deduplicate.zig +++ b/src/link/SpirV/deduplicate.zig @@ -35,24 +35,24 @@ const ModuleInfo = struct { /// The type that this entity represents. This is just /// the instruction opcode. kind: Opcode, - /// Offset of first child result-id, stored in entity_children. - /// These are the shallow entities appearing directly in the - /// type's instruction. - first_child: u32, - /// Offset to the first word of extra-data: Data in the instruction - /// that must be considered for uniqueness, but doesn't include - /// any IDs. - first_extra_data: u32, + /// The offset of this entity's operands, in + /// `binary.instructions`. + first_operand: u32, + /// The number of operands in this entity + num_operands: u16, + /// The (first_operand-relative) offset of the result-id, + /// or the entity that is affected by this entity if this entity + /// is a decoration. + result_id_index: u16, }; /// Maps result-id to Entity's entities: std.AutoArrayHashMapUnmanaged(ResultId, Entity), - /// The list of children per instruction. - entity_children: []const ResultId, - /// The list of extra data per instruction. - /// TODO: This is a bit awkward, maybe we need to store it some - /// other way? - extra_data: []const u32, + /// A bit set that keeps track of which operands are result-ids. + /// Note: This also includes any result-id! 
+ /// Because we need these values when recoding the module anyway, + /// it contains the status of ALL operands in the module. + operand_is_id: std.DynamicBitSetUnmanaged, pub fn parse( arena: Allocator, @@ -60,19 +60,22 @@ const ModuleInfo = struct { binary: BinaryModule, ) !ModuleInfo { var entities = std.AutoArrayHashMap(ResultId, Entity).init(arena); - var entity_children = std.ArrayList(ResultId).init(arena); - var extra_data = std.ArrayList(u32).init(arena); var id_offsets = std.ArrayList(u16).init(arena); + var operand_is_id = try std.DynamicBitSetUnmanaged.initEmpty(arena, binary.instructions.len); var it = binary.iterateInstructions(); while (it.next()) |inst| { - if (inst.opcode == .OpFunction) break; // No more declarations are possible - if (!canDeduplicate(inst.opcode)) continue; - id_offsets.items.len = 0; try parser.parseInstructionResultIds(binary, inst, &id_offsets); - const result_id_index: u32 = switch (inst.opcode.class()) { + const first_operand_offset: u32 = @intCast(inst.offset + 1); + for (id_offsets.items) |offset| { + operand_is_id.set(first_operand_offset + offset); + } + + if (!canDeduplicate(inst.opcode)) continue; + + const result_id_index: u16 = switch (inst.opcode.class()) { .TypeDeclaration, .Annotation, .Debug => 0, .ConstantCreation => 1, else => unreachable, @@ -80,27 +83,6 @@ const ModuleInfo = struct { const result_id: ResultId = @enumFromInt(inst.operands[id_offsets.items[result_id_index]]); - const first_child: u32 = @intCast(entity_children.items.len); - const first_extra_data: u32 = @intCast(extra_data.items.len); - - try entity_children.ensureUnusedCapacity(id_offsets.items.len - 1); - try extra_data.ensureUnusedCapacity(inst.operands.len - id_offsets.items.len); - - var id_i: usize = 0; - for (inst.operands, 0..) |operand, i| { - assert(id_i == id_offsets.items.len or id_offsets.items[id_i] >= i); - if (id_i != id_offsets.items.len and id_offsets.items[id_i] == i) { - // Skip .IdResult / .IdResultType. - if (id_i != result_id_index) { - entity_children.appendAssumeCapacity(@enumFromInt(operand)); - } - id_i += 1; - } else { - // Non-id operand, add it to extra data. - extra_data.appendAssumeCapacity(operand); - } - } - switch (inst.opcode.class()) { .Annotation, .Debug => { // TODO @@ -113,8 +95,9 @@ const ModuleInfo = struct { } entry.value_ptr.* = .{ .kind = inst.opcode, - .first_child = first_child, - .first_extra_data = first_extra_data, + .first_operand = first_operand_offset, + .num_operands = @intCast(inst.operands.len), + .result_id_index = result_id_index, }; }, else => unreachable, @@ -123,34 +106,9 @@ const ModuleInfo = struct { return ModuleInfo{ .entities = entities.unmanaged, - .entity_children = entity_children.items, - .extra_data = extra_data.items, + .operand_is_id = operand_is_id, }; } - - /// Fetch a slice of children for the index corresponding to an entity. - fn childrenByIndex(self: ModuleInfo, index: usize) []const ResultId { - const values = self.entities.values(); - const first_child = values[index].first_child; - if (index == values.len - 1) { - return self.entity_children[first_child..]; - } else { - const next_first_child = values[index + 1].first_child; - return self.entity_children[first_child..next_first_child]; - } - } - - /// Fetch the slice of extra-data for the index corresponding to an entity. 
- fn extraDataByIndex(self: ModuleInfo, index: usize) []const u32 { - const values = self.entities.values(); - const first_extra_data = values[index].first_extra_data; - if (index == values.len - 1) { - return self.extra_data[first_extra_data..]; - } else { - const next_extra_data = values[index + 1].first_extra_data; - return self.extra_data[first_extra_data..next_extra_data]; - } - } }; const EntityContext = struct { @@ -158,13 +116,7 @@ const EntityContext = struct { ptr_map_a: std.AutoArrayHashMapUnmanaged(ResultId, void) = .{}, ptr_map_b: std.AutoArrayHashMapUnmanaged(ResultId, void) = .{}, info: *const ModuleInfo, - - fn init(a: Allocator, info: *const ModuleInfo) EntityContext { - return .{ - .a = a, - .info = info, - }; - } + binary: *const BinaryModule, fn deinit(self: *EntityContext) void { self.ptr_map_a.deinit(self.a); @@ -203,14 +155,19 @@ const EntityContext = struct { } } - // Hash extra data - for (self.info.extraDataByIndex(index)) |data| { - std.hash.autoHash(hasher, data); - } - - // Hash children - for (self.info.childrenByIndex(index)) |child| { - try self.hashInner(hasher, child); + // Process operands + const operands = self.binary.instructions[entity.first_operand..][0..entity.num_operands]; + for (operands, 0..) |operand, i| { + if (i == entity.result_id_index) { + // Not relevant, skip... + continue; + } else if (self.info.operand_is_id.isSet(entity.first_operand + i)) { + // Operand is ID + try self.hashInner(hasher, @enumFromInt(operand)); + } else { + // Operand is merely data + std.hash.autoHash(hasher, operand); + } } } @@ -228,7 +185,11 @@ const EntityContext = struct { const entity_a = self.info.entities.values()[index_a]; const entity_b = self.info.entities.values()[index_b]; - if (entity_a.kind != entity_b.kind) return false; + if (entity_a.kind != entity_b.kind) { + return false; + } else if (entity_a.result_id_index != entity_a.result_id_index) { + return false; + } if (entity_a.kind == .OpTypePointer) { // May be a forward reference, or should be saved as a potential @@ -246,18 +207,28 @@ const EntityContext = struct { } } - // Check if extra data is the same. - if (!std.mem.eql(u32, self.info.extraDataByIndex(index_a), self.info.extraDataByIndex(index_b))) { + const operands_a = self.binary.instructions[entity_a.first_operand..][0..entity_a.num_operands]; + const operands_b = self.binary.instructions[entity_b.first_operand..][0..entity_b.num_operands]; + + // Note: returns false for operands that have explicit defaults in optional operands... oh well + if (operands_a.len != operands_b.len) { return false; } - // Recursively check if children are the same - const children_a = self.info.childrenByIndex(index_a); - const children_b = self.info.childrenByIndex(index_b); - if (children_a.len != children_b.len) return false; - - for (children_a, children_b) |child_a, child_b| { - if (!try self.eqlInner(child_a, child_b)) { + for (operands_a, operands_b, 0..) |operand_a, operand_b, i| { + const a_is_id = self.info.operand_is_id.isSet(entity_a.first_operand + i); + const b_is_id = self.info.operand_is_id.isSet(entity_b.first_operand + i); + if (a_is_id != b_is_id) { + return false; + } else if (i == entity_a.result_id_index) { + // result-id for both... + continue; + } else if (a_is_id) { + // Both are IDs, so recurse. 
+ if (!try self.eqlInner(@enumFromInt(operand_a), @enumFromInt(operand_b))) { + return false; + } + } else if (operand_a != operand_b) { return false; } } @@ -290,11 +261,13 @@ pub fn run(parser: *BinaryModule.Parser, binary: *BinaryModule) !void { const info = try ModuleInfo.parse(a, parser, binary.*); log.info("added {} entities", .{info.entities.count()}); - log.info("children size: {}", .{info.entity_children.len}); - log.info("extra data size: {}", .{info.extra_data.len}); // Hash all keys once so that the maps can be allocated the right size. - var ctx = EntityContext.init(a, &info); + var ctx = EntityContext{ + .a = a, + .info = &info, + .binary = binary, + }; for (info.entities.keys()) |id| { _ = try ctx.hash(id); } @@ -318,7 +291,6 @@ pub fn run(parser: *BinaryModule.Parser, binary: *BinaryModule) !void { // Now process the module, and replace instructions where needed. var section = Section{}; var it = binary.iterateInstructions(); - var id_offsets = std.ArrayList(u16).init(a); var new_functions_section: ?usize = null; var new_operands = std.ArrayList(u32).init(a); var emitted_ptrs = std.AutoHashMap(ResultId, void).init(a); @@ -347,38 +319,31 @@ pub fn run(parser: *BinaryModule.Parser, binary: *BinaryModule) !void { // Re-emit the instruction, but replace all the IDs. - id_offsets.items.len = 0; - try parser.parseInstructionResultIds(binary.*, inst, &id_offsets); - new_operands.items.len = 0; try new_operands.appendSlice(inst.operands); - for (id_offsets.items) |offset| { - { - const id: ResultId = @enumFromInt(inst.operands[offset]); - if (replace.get(id)) |new_id| { - new_operands.items[offset] = @intFromEnum(new_id); - } + + for (new_operands.items, 0..) |*operand, i| { + const is_id = info.operand_is_id.isSet(inst.offset + 1 + i); + if (!is_id) continue; + + if (replace.get(@enumFromInt(operand.*))) |new_id| { + operand.* = @intFromEnum(new_id); } - // TODO: Does this logic work? Maybe it will emit an OpTypeForwardPointer to - // something thats not a struct... - // It seems to work correctly on behavior.zig at least - const id: ResultId = @enumFromInt(new_operands.items[offset]); + const id: ResultId = @enumFromInt(operand.*); + // TODO: This test is a little janky. Check the offset instead? if (maybe_result_id == null or maybe_result_id.? != id) { const index = info.entities.getIndex(id) orelse continue; const entity = info.entities.values()[index]; - if (entity.kind == .OpTypePointer) { - if (!emitted_ptrs.contains(id)) { - // The storage class is in the extra data - // TODO: This is kind of hacky... - const extra_data = info.extraDataByIndex(index); - const storage_class: spec.StorageClass = @enumFromInt(extra_data[0]); - try section.emit(a, .OpTypeForwardPointer, .{ - .pointer_type = id, - .storage_class = storage_class, - }); - try emitted_ptrs.put(id, {}); - } + if (entity.kind == .OpTypePointer and !emitted_ptrs.contains(id)) { + // Grab the pointer's storage class from its operands in the original + // module. 
+ const storage_class: spec.StorageClass = @enumFromInt(binary.instructions[entity.first_operand + 1]); + try section.emit(a, .OpTypeForwardPointer, .{ + .pointer_type = id, + .storage_class = storage_class, + }); + try emitted_ptrs.put(id, {}); } } } From f5ab3c93c9a083b730e91e362001da7b32668938 Mon Sep 17 00:00:00 2001 From: Robin Voetter Date: Sat, 30 Mar 2024 10:06:55 +0100 Subject: [PATCH 3/5] spirv: handle annotations in deduplication pass --- src/link/SpirV/deduplicate.zig | 183 +++++++++++++++++++++++++++------ 1 file changed, 150 insertions(+), 33 deletions(-) diff --git a/src/link/SpirV/deduplicate.zig b/src/link/SpirV/deduplicate.zig index fafc608e957b..4a73276a9a02 100644 --- a/src/link/SpirV/deduplicate.zig +++ b/src/link/SpirV/deduplicate.zig @@ -17,7 +17,8 @@ fn canDeduplicate(opcode: Opcode) bool { // These are deprecated, so don't bother supporting them for now. return false; }, - .OpName, .OpMemberName => true, // Debug decoration-style instructions + // Debug decoration-style instructions + .OpName, .OpMemberName => true, else => switch (opcode.class()) { .TypeDeclaration, .ConstantCreation, @@ -44,6 +45,8 @@ const ModuleInfo = struct { /// or the entity that is affected by this entity if this entity /// is a decoration. result_id_index: u16, + /// The first decoration in `self.decorations`. + first_decoration: u32, }; /// Maps result-id to Entity's @@ -53,6 +56,8 @@ const ModuleInfo = struct { /// Because we need these values when recoding the module anyway, /// it contains the status of ALL operands in the module. operand_is_id: std.DynamicBitSetUnmanaged, + /// Store of decorations for each entity. + decorations: []const Entity, pub fn parse( arena: Allocator, @@ -62,6 +67,7 @@ const ModuleInfo = struct { var entities = std.AutoArrayHashMap(ResultId, Entity).init(arena); var id_offsets = std.ArrayList(u16).init(arena); var operand_is_id = try std.DynamicBitSetUnmanaged.initEmpty(arena, binary.instructions.len); + var decorations = std.MultiArrayList(struct { target_id: ResultId, entity: Entity }){}; var it = binary.iterateInstructions(); while (it.next()) |inst| { @@ -82,10 +88,20 @@ const ModuleInfo = struct { }; const result_id: ResultId = @enumFromInt(inst.operands[id_offsets.items[result_id_index]]); + const entity = Entity{ + .kind = inst.opcode, + .first_operand = first_operand_offset, + .num_operands = @intCast(inst.operands.len), + .result_id_index = result_id_index, + .first_decoration = undefined, // Filled in later + }; switch (inst.opcode.class()) { .Annotation, .Debug => { - // TODO + try decorations.append(arena, .{ + .target_id = result_id, + .entity = entity, + }); }, .TypeDeclaration, .ConstantCreation => { const entry = try entities.getOrPut(result_id); @@ -93,22 +109,67 @@ const ModuleInfo = struct { log.err("type or constant {} has duplicate definition", .{result_id}); return error.DuplicateId; } - entry.value_ptr.* = .{ - .kind = inst.opcode, - .first_operand = first_operand_offset, - .num_operands = @intCast(inst.operands.len), - .result_id_index = result_id_index, - }; + entry.value_ptr.* = entity; }, else => unreachable, } } + // Sort decorations by the index of the result-id in `entities. + // This ensures not only that the decorations of a particular reuslt-id + // are continuous, but the subsequences also appear in the same order as in `entities`. 
+ + const SortContext = struct { + entities: std.AutoArrayHashMapUnmanaged(ResultId, Entity), + ids: []const ResultId, + + pub fn lessThan(ctx: @This(), a_index: usize, b_index: usize) bool { + // If any index is not in the entities set, its because its not a + // deduplicatable result-id. Those should be considered largest and + // float to the end. + const entity_index_a = ctx.entities.getIndex(ctx.ids[a_index]) orelse return false; + const entity_index_b = ctx.entities.getIndex(ctx.ids[b_index]) orelse return true; + + return entity_index_a < entity_index_b; + } + }; + + decorations.sort(SortContext{ + .entities = entities.unmanaged, + .ids = decorations.items(.target_id), + }); + + // Now go through the decorations and add the offsets to the entities list. + var decoration_i: u32 = 0; + const target_ids = decorations.items(.target_id); + for (entities.keys(), entities.values()) |id, *entity| { + entity.first_decoration = decoration_i; + + // Scan ahead to the next decoration + while (decoration_i < target_ids.len and target_ids[decoration_i] == id) { + decoration_i += 1; + } + } + return ModuleInfo{ .entities = entities.unmanaged, .operand_is_id = operand_is_id, + // There may be unrelated decorations at the end, so make sure to + // slice those off. + .decorations = decorations.items(.entity)[0..decoration_i], }; } + + fn entityDecorationsByIndex(self: ModuleInfo, index: usize) []const Entity { + const values = self.entities.values(); + const first_decoration = values[index].first_decoration; + if (index == values.len - 1) { + return self.decorations[first_decoration..]; + } else { + const next_first_decoration = values[index + 1].first_decoration; + return self.decorations[first_decoration..next_first_decoration]; + } + } }; const EntityContext = struct { @@ -138,23 +199,39 @@ const EntityContext = struct { return hasher.final(); } - fn hashInner(self: *EntityContext, hasher: *std.hash.Wyhash, id: ResultId) !void { - const index = self.info.entities.getIndex(id).?; + fn hashInner(self: *EntityContext, hasher: *std.hash.Wyhash, id: ResultId) error{OutOfMemory}!void { + const index = self.info.entities.getIndex(id) orelse { + // Index unknown, the type or constant may depend on another result-id + // that couldn't be deduplicated and so it wasn't added to info.entities. + // In this case, just has the ID itself. + std.hash.autoHash(hasher, id); + return; + }; + const entity = self.info.entities.values()[index]; - std.hash.autoHash(hasher, entity.kind); if (entity.kind == .OpTypePointer) { // This may be either a pointer that is forward-referenced in the future, // or a forward reference to a pointer. const entry = try self.ptr_map_a.getOrPut(self.a, id); if (entry.found_existing) { // Pointer already seen. Hash the index instead of recursing into its children. - // TODO: Discriminate this path somehow? std.hash.autoHash(hasher, entry.index); return; } } + try self.hashEntity(hasher, entity); + + // Process decorations. + const decorations = self.info.entityDecorationsByIndex(index); + for (decorations) |decoration| { + try self.hashEntity(hasher, decoration); + } + } + + fn hashEntity(self: *EntityContext, hasher: *std.hash.Wyhash, entity: ModuleInfo.Entity) !void { + std.hash.autoHash(hasher, entity.kind); // Process operands const operands = self.binary.instructions[entity.first_operand..][0..entity.num_operands]; for (operands, 0..) 
|operand, i| { @@ -178,19 +255,24 @@ const EntityContext = struct { return try self.eqlInner(a, b); } - fn eqlInner(self: *EntityContext, id_a: ResultId, id_b: ResultId) !bool { - const index_a = self.info.entities.getIndex(id_a).?; - const index_b = self.info.entities.getIndex(id_b).?; + fn eqlInner(self: *EntityContext, id_a: ResultId, id_b: ResultId) error{OutOfMemory}!bool { + const maybe_index_a = self.info.entities.getIndex(id_a); + const maybe_index_b = self.info.entities.getIndex(id_b); + + if (maybe_index_a == null and maybe_index_b == null) { + // Both indices unknown. In this case the type or constant + // may depend on another result-id that couldn't be deduplicated + // (so it wasn't added to info.entities). In this case, that particular + // result-id should be the same one. + return id_a == id_b; + } + + const index_a = maybe_index_a orelse return false; + const index_b = maybe_index_b orelse return false; const entity_a = self.info.entities.values()[index_a]; const entity_b = self.info.entities.values()[index_b]; - if (entity_a.kind != entity_b.kind) { - return false; - } else if (entity_a.result_id_index != entity_a.result_id_index) { - return false; - } - if (entity_a.kind == .OpTypePointer) { // May be a forward reference, or should be saved as a potential // forward reference in the future. Whatever the case, it should @@ -207,6 +289,33 @@ const EntityContext = struct { } } + if (!try self.eqlEntities(entity_a, entity_b)) { + return false; + } + + // Compare decorations. + const decorations_a = self.info.entityDecorationsByIndex(index_a); + const decorations_b = self.info.entityDecorationsByIndex(index_b); + if (decorations_a.len != decorations_b.len) { + return false; + } + + for (decorations_a, decorations_b) |decoration_a, decoration_b| { + if (!try self.eqlEntities(decoration_a, decoration_b)) { + return false; + } + } + + return true; + } + + fn eqlEntities(self: *EntityContext, entity_a: ModuleInfo.Entity, entity_b: ModuleInfo.Entity) !bool { + if (entity_a.kind != entity_b.kind) { + return false; + } else if (entity_a.result_id_index != entity_a.result_id_index) { + return false; + } + const operands_a = self.binary.instructions[entity_a.first_operand..][0..entity_a.num_operands]; const operands_b = self.binary.instructions[entity_b.first_operand..][0..entity_b.num_operands]; @@ -260,7 +369,6 @@ pub fn run(parser: *BinaryModule.Parser, binary: *BinaryModule) !void { const a = arena.allocator(); const info = try ModuleInfo.parse(a, parser, binary.*); - log.info("added {} entities", .{info.entities.count()}); // Hash all keys once so that the maps can be allocated the right size. 
var ctx = EntityContext{ @@ -280,10 +388,9 @@ pub fn run(parser: *BinaryModule.Parser, binary: *BinaryModule) !void { .entity_context = &ctx, }); var replace = std.AutoArrayHashMap(ResultId, ResultId).init(a); - for (info.entities.keys(), info.entities.values()) |id, entity| { + for (info.entities.keys()) |id| { const entry = try map.getOrPut(id); if (entry.found_existing) { - log.info("deduplicating {} - {s} (prior definition: {})", .{ id, @tagName(entity.kind), entry.key_ptr.* }); try replace.putNoClobber(id, entry.key_ptr.*); } } @@ -297,13 +404,15 @@ pub fn run(parser: *BinaryModule.Parser, binary: *BinaryModule) !void { while (it.next()) |inst| { // Result-id can only be the first or second operand const inst_spec = parser.getInstSpec(inst.opcode).?; - const maybe_result_id: ?ResultId = for (0..2) |i| { + + const maybe_result_id_offset: ?u16 = for (0..2) |i| { if (inst_spec.operands.len > i and inst_spec.operands[i].kind == .IdResult) { - break @enumFromInt(inst.operands[i]); + break @intCast(i); } } else null; - if (maybe_result_id) |result_id| { + if (maybe_result_id_offset) |offset| { + const result_id: ResultId = @enumFromInt(inst.operands[offset]); if (replace.contains(result_id)) continue; } @@ -312,8 +421,16 @@ pub fn run(parser: *BinaryModule.Parser, binary: *BinaryModule) !void { new_functions_section = section.instructions.items.len; }, .OpTypeForwardPointer => continue, // We re-emit these where needed - // TODO: These aren't supported yet, strip them out for testing purposes. - .OpName, .OpMemberName => continue, + else => {}, + } + + switch (inst.opcode.class()) { + .Annotation, .Debug => { + // For decoration-style instructions, only emit them + // if the target is not removed. + const target: ResultId = @enumFromInt(inst.operands[0]); + if (replace.contains(target)) continue; + }, else => {}, } @@ -330,9 +447,8 @@ pub fn run(parser: *BinaryModule.Parser, binary: *BinaryModule) !void { operand.* = @intFromEnum(new_id); } - const id: ResultId = @enumFromInt(operand.*); - // TODO: This test is a little janky. Check the offset instead? - if (maybe_result_id == null or maybe_result_id.? != id) { + if (maybe_result_id_offset == null or maybe_result_id_offset.? != i) { + const id: ResultId = @enumFromInt(operand.*); const index = info.entities.getIndex(id) orelse continue; const entity = info.entities.values()[index]; if (entity.kind == .OpTypePointer and !emitted_ptrs.contains(id)) { @@ -349,7 +465,8 @@ pub fn run(parser: *BinaryModule.Parser, binary: *BinaryModule) !void { } if (inst.opcode == .OpTypePointer) { - try emitted_ptrs.put(maybe_result_id.?, {}); + const result_id: ResultId = @enumFromInt(new_operands.items[maybe_result_id_offset.?]); + try emitted_ptrs.put(result_id, {}); } try section.emitRawInstruction(a, inst.opcode, new_operands.items); From 12350f53bf5e39403e11b9902ca85bea97f65722 Mon Sep 17 00:00:00 2001 From: Robin Voetter Date: Sat, 30 Mar 2024 18:30:28 +0100 Subject: [PATCH 4/5] spirv: clz, ctz for opencl This instruction seems common in compiler_rt. 
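
A minimal sketch (illustrative only, not part of the patch) of the kind of Zig
code this enables on the OpenCL SPIR-V target; the function name is made up,
but the lowering it describes is the one implemented below:

    // Each builtin becomes an OpExtInst using the OpenCL.std extended
    // instructions clz (151) / ctz (152); when the result type width differs
    // from the operand width, an OpUConvert/OpSConvert is appended.
    fn leadingAndTrailing(x: u32) u32 {
        return @as(u32, @clz(x)) + @as(u32, @ctz(x));
    }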
--- src/codegen/spirv.zig | 80 ++++++++++++++++++++++++++++++++++++++++++ test/behavior/math.zig | 3 -- 2 files changed, 80 insertions(+), 3 deletions(-) diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig index 6b13f2623aa3..9113d72d927d 100644 --- a/src/codegen/spirv.zig +++ b/src/codegen/spirv.zig @@ -2332,6 +2332,9 @@ const DeclGen = struct { .mul_add => try self.airMulAdd(inst), + .ctz => try self.airClzCtz(inst, .ctz), + .clz => try self.airClzCtz(inst, .clz), + .splat => try self.airSplat(inst), .reduce, .reduce_optimized => try self.airReduce(inst), .shuffle => try self.airShuffle(inst), @@ -3029,6 +3032,83 @@ const DeclGen = struct { return try wip.finalize(); } + fn airClzCtz(self: *DeclGen, inst: Air.Inst.Index, op: enum { clz, ctz }) !?IdRef { + if (self.liveness.isUnused(inst)) return null; + + const mod = self.module; + const target = self.getTarget(); + const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op; + const result_ty = self.typeOfIndex(inst); + const operand_ty = self.typeOf(ty_op.operand); + const operand = try self.resolve(ty_op.operand); + + const info = self.arithmeticTypeInfo(operand_ty); + switch (info.class) { + .composite_integer => unreachable, // TODO + .integer, .strange_integer => {}, + .float, .bool => unreachable, + } + + var wip = try self.elementWise(result_ty, false); + defer wip.deinit(); + + const elem_ty = if (wip.is_array) operand_ty.scalarType(mod) else operand_ty; + const elem_ty_ref = try self.resolveType(elem_ty, .direct); + const elem_ty_id = self.typeId(elem_ty_ref); + + for (wip.results, 0..) |*result_id, i| { + const elem = try wip.elementAt(operand_ty, operand, i); + + switch (target.os.tag) { + .opencl => { + const set = try self.spv.importInstructionSet(.@"OpenCL.std"); + const ext_inst: u32 = switch (op) { + .clz => 151, // clz + .ctz => 152, // ctz + }; + + // Note: result of OpenCL ctz/clz returns operand_ty, and we want result_ty. + // result_ty is always large enough to hold the result, so we might have to down + // cast it. 
+ const tmp = self.spv.allocId(); + try self.func.body.emit(self.spv.gpa, .OpExtInst, .{ + .id_result_type = elem_ty_id, + .id_result = tmp, + .set = set, + .instruction = .{ .inst = ext_inst }, + .id_ref_4 = &.{elem}, + }); + + if (wip.ty_id == elem_ty_id) { + result_id.* = tmp; + continue; + } + + result_id.* = self.spv.allocId(); + if (result_ty.scalarType(mod).isSignedInt(mod)) { + assert(elem_ty.scalarType(mod).isSignedInt(mod)); + try self.func.body.emit(self.spv.gpa, .OpSConvert, .{ + .id_result_type = wip.ty_id, + .id_result = result_id.*, + .signed_value = tmp, + }); + } else { + assert(elem_ty.scalarType(mod).isUnsignedInt(mod)); + try self.func.body.emit(self.spv.gpa, .OpUConvert, .{ + .id_result_type = wip.ty_id, + .id_result = result_id.*, + .unsigned_value = tmp, + }); + } + }, + .vulkan => unreachable, // TODO + else => unreachable, + } + } + + return try wip.finalize(); + } + fn airSplat(self: *DeclGen, inst: Air.Inst.Index) !?IdRef { const ty_op = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_op; const operand_id = try self.resolve(ty_op.operand); diff --git a/test/behavior/math.zig b/test/behavior/math.zig index 092924af06ed..fbd8369219c6 100644 --- a/test/behavior/math.zig +++ b/test/behavior/math.zig @@ -65,7 +65,6 @@ test "@clz" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; try testClz(); try comptime testClz(); @@ -148,7 +147,6 @@ test "@ctz" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; try testCtz(); try comptime testCtz(); @@ -1752,7 +1750,6 @@ test "@clz works on both vector and scalar inputs" { if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO - if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; var x: u32 = 0x1; _ = &x; From 27b91288dc3c0442b645e06cae75c076600bdfd3 Mon Sep 17 00:00:00 2001 From: Robin Voetter Date: Fri, 29 Mar 2024 13:54:40 +0100 Subject: [PATCH 5/5] spirv: disable failing tests --- test/behavior/destructure.zig | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/behavior/destructure.zig b/test/behavior/destructure.zig index 43ddbb7a4de0..78ee999ddb12 100644 --- a/test/behavior/destructure.zig +++ b/test/behavior/destructure.zig @@ -23,6 +23,8 @@ test "simple destructure" { } test "destructure with comptime syntax" { + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; + const S = struct { fn doTheTest() !void { {