Skip to content

Commit

Permalink
nix component: build hook: add expectPacket() and error sets to sig…
Browse files Browse the repository at this point in the history
…natures
  • Loading branch information
dermetfan committed Jul 15, 2024
1 parent 6a478d4 commit 52aaaed
Showing 1 changed file with 36 additions and 9 deletions.
45 changes: 36 additions & 9 deletions src/components/nix/build-hook/protocol.zig
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,20 @@ const std = @import("std");

pub const block_len = 8;

pub const PaddingError = error{
/// The padding contains bytes that are not zeroes.
BadPadding,

/// The stream ended before the expected amount of padding could be read.
EndOfStream,
};

pub fn ReadError(comptime Reader: type, allocates: bool) type {
var set = Reader.NoEofError || PaddingError;
if (allocates) set = set || std.mem.Allocator.Error;
return set;
}

/// Returns the number of padding bytes.
pub fn padding(len: usize) std.math.IntFittingRange(0, block_len) {
return if (len % block_len == 0) 0 else @intCast(block_len - len % block_len);
Expand All @@ -15,13 +29,13 @@ test padding {
}

/// Reads the padding for the given length and asserts it is all zeroes.
pub fn readPadding(reader: anytype, len: usize) !void {
pub fn readPadding(reader: anytype, len: usize) ReadError(@TypeOf(reader), false)!void {
const padding_len = padding(len);
if (padding_len == 0) return;

var padding_buf: [block_len]u8 = undefined;
const padding_slice = padding_buf[0..padding_len];
if (try reader.readAll(padding_slice) < padding_slice.len) return error.EndOfStream;
try reader.readNoEof(padding_slice);

if (!std.mem.allEqual(u8, padding_slice, 0)) return error.BadPadding;
}
Expand All @@ -43,7 +57,7 @@ test readPadding {
}

/// Fills the buffer and discards the padding.
pub fn readPadded(reader: anytype, buf: []u8) !void {
pub fn readPadded(reader: anytype, buf: []u8) ReadError(@TypeOf(reader), false)!void {
if (try reader.readAll(buf) < buf.len) return error.EndOfStream;
try readPadding(reader, buf.len);
}
Expand All @@ -64,26 +78,26 @@ test readPadded {
}
}

pub fn readU64(reader: anytype) !u64 {
pub fn readU64(reader: anytype) ReadError(@TypeOf(reader), false)!u64 {
return reader.readInt(u64, .little);
}

pub fn readBool(reader: anytype) !bool {
pub fn readBool(reader: anytype) (ReadError(@TypeOf(reader), false) || error{BadBool})!bool {
return switch (try readU64(reader)) {
0 => false,
1 => true,
else => error.BadBool,
};
}

pub fn readPacket(allocator: std.mem.Allocator, reader: anytype) ![]const u8 {
pub fn readPacket(allocator: std.mem.Allocator, reader: anytype) ReadError(@TypeOf(reader), true)![]const u8 {
const buf = try allocator.alloc(u8, try readU64(reader));
errdefer allocator.free(buf);
try readPadded(reader, buf);
return buf;
}

pub fn readPackets(allocator: std.mem.Allocator, reader: anytype) ![]const []const u8 {
pub fn readPackets(allocator: std.mem.Allocator, reader: anytype) ReadError(@TypeOf(reader), true)![]const []const u8 {
const bufs = try allocator.alloc([]const u8, try readU64(reader));
errdefer {
for (bufs) |buf| allocator.free(buf);
Expand All @@ -93,7 +107,7 @@ pub fn readPackets(allocator: std.mem.Allocator, reader: anytype) ![]const []con
return bufs;
}

pub fn readStringStringMap(allocator: std.mem.Allocator, reader: anytype) !std.BufMap {
pub fn readStringStringMap(allocator: std.mem.Allocator, reader: anytype) ReadError(@TypeOf(reader), true)!std.BufMap {
// We build the `std.BufMap`'s underlying `.hash_map` directly
// so that we don't have to copy keys and values twice:
// Once when reading and once when inserting.
Expand All @@ -118,7 +132,7 @@ pub fn readStringStringMap(allocator: std.mem.Allocator, reader: anytype) !std.B
}

/// Reads fields in declaration order.
pub fn readStruct(comptime T: type, allocator: std.mem.Allocator, reader: anytype) !T {
pub fn readStruct(comptime T: type, allocator: std.mem.Allocator, reader: anytype) (ReadError(@TypeOf(reader), true) || error{BadBool})!T {
var strukt: T = undefined;

const fields = @typeInfo(T).Struct.fields;
Expand Down Expand Up @@ -154,3 +168,16 @@ pub fn readStruct(comptime T: type, allocator: std.mem.Allocator, reader: anytyp

return strukt;
}

pub fn expectPacket(comptime expected: []const u8, reader: anytype) (ReadError(@TypeOf(reader), true) || error{UnexpectedPacket})!void {
var buf: [expected.len]u8 = undefined;
var fba = std.heap.FixedBufferAllocator.init(&buf);

const packet = readPacket(fba.allocator(), reader) catch |err| return switch (err) {
error.OutOfMemory => error.UnexpectedPacket,
else => err,
};

if (!std.mem.eql(u8, packet, expected))
return error.UnexpectedPacket;
}

0 comments on commit 52aaaed

Please sign in to comment.