From aa57513735dcce88179cf2401703dcdca3d77c54 Mon Sep 17 00:00:00 2001 From: Lorenz Bauer Date: Tue, 12 Sep 2023 16:28:17 +0100 Subject: [PATCH] map: zero-allocation operations for common types MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Map keys and values are currently marshaled into []byte by souped up versions of binary.Write and binary.Read. This allows users to be blissfully unaware of compiler inserted padding on the Go side. This is wasteful in case the Go in-memory representation matches what the kernel expects because we need additional allocations. Refactor syscall marshaling into a new package sysenc which encapsulates the logic we need to determine whether a Go type is safe for zero-allocation / zero-copy marshaling. The type must be a pointer to or a slice of: * A primitive type like uint32, ... or * An array of valid types or * A struct made up of valid types without any compiler inserted padding between fields Per-CPU maps don't support zero-allocation operations for now, but the new code already makes things a little bit cheaper. Structs with trailing padding also don't benefit from the optimization for now. Consider type padded struct { A uint32; B uint16 } Allowing such a type creates an edge case: make([]padded, 1) uses zero-allocation marshaling while make([]padded, 2) doesn't, due to interior padding. It's simpler to skip such types for now. 
goos: linux goarch: amd64 pkg: github.com/cilium/ebpf cpu: 12th Gen Intel(R) Core(TM) i7-1260P │ unsafe.txt │ │ sec/op │ Marshaling/ValueUnmarshalReflect-16 356.1n ± 2% Marshaling/KeyMarshalReflect-16 368.6n ± 1% Marshaling/ValueBinaryUnmarshaler-16 378.6n ± 2% Marshaling/KeyBinaryMarshaler-16 356.2n ± 1% Marshaling/KeyValueUnsafe-16 328.0n ± 2% PerCPUMarshalling/reflection-16 1.232µ ± 1% │ unsafe.txt │ │ B/op │ Marshaling/ValueUnmarshalReflect-16 0.000 ± 0% Marshaling/KeyMarshalReflect-16 0.000 ± 0% Marshaling/ValueBinaryUnmarshaler-16 24.00 ± 0% Marshaling/KeyBinaryMarshaler-16 8.000 ± 0% Marshaling/KeyValueUnsafe-16 0.000 ± 0% PerCPUMarshalling/reflection-16 280.0 ± 0% │ unsafe.txt │ │ allocs/op │ Marshaling/ValueUnmarshalReflect-16 0.000 ± 0% Marshaling/KeyMarshalReflect-16 0.000 ± 0% Marshaling/ValueBinaryUnmarshaler-16 1.000 ± 0% Marshaling/KeyBinaryMarshaler-16 1.000 ± 0% Marshaling/KeyValueUnsafe-16 0.000 ± 0% PerCPUMarshalling/reflection-16 3.000 ± 0% Signed-off-by: Lorenz Bauer --- collection.go | 5 +- internal/endian_be.go | 2 +- internal/endian_le.go | 2 +- internal/sysenc/buffer.go | 77 ++++++++ internal/sysenc/buffer_test.go | 27 +++ internal/sysenc/doc.go | 3 + internal/sysenc/layout.go | 41 +++++ internal/sysenc/layout_test.go | 33 ++++ internal/sysenc/marshal.go | 163 +++++++++++++++++ internal/sysenc/marshal_test.go | 306 ++++++++++++++++++++++++++++++++ map.go | 75 ++++---- marshalers.go | 162 +++-------------- prog.go | 9 +- prog_test.go | 2 +- syscalls.go | 4 +- 15 files changed, 715 insertions(+), 196 deletions(-) create mode 100644 internal/sysenc/buffer.go create mode 100644 internal/sysenc/buffer_test.go create mode 100644 internal/sysenc/doc.go create mode 100644 internal/sysenc/layout.go create mode 100644 internal/sysenc/layout_test.go create mode 100644 internal/sysenc/marshal.go create mode 100644 internal/sysenc/marshal_test.go diff --git a/collection.go b/collection.go index fb720bebd..8e66336c9 100644 --- a/collection.go +++ 
b/collection.go @@ -11,6 +11,7 @@ import ( "github.com/cilium/ebpf/btf" "github.com/cilium/ebpf/internal" "github.com/cilium/ebpf/internal/kconfig" + "github.com/cilium/ebpf/internal/sysenc" ) // CollectionOptions control loading a collection into the kernel. @@ -175,12 +176,12 @@ func (cs *CollectionSpec) RewriteConstants(consts map[string]interface{}) error return fmt.Errorf("section %s: offset %d(+%d) for variable %s is out of bounds", name, v.Offset, v.Size, vname) } - b, err := marshalBytes(replacement, int(v.Size)) + b, err := sysenc.Marshal(replacement, int(v.Size)) if err != nil { return fmt.Errorf("marshaling constant replacement %s: %w", vname, err) } - copy(cpy[v.Offset:v.Offset+v.Size], b) + b.CopyTo(cpy[v.Offset : v.Offset+v.Size]) replaced[vname] = true } diff --git a/internal/endian_be.go b/internal/endian_be.go index 96a2ac0de..39f49ba3a 100644 --- a/internal/endian_be.go +++ b/internal/endian_be.go @@ -6,7 +6,7 @@ import "encoding/binary" // NativeEndian is set to either binary.BigEndian or binary.LittleEndian, // depending on the host's endianness. -var NativeEndian binary.ByteOrder = binary.BigEndian +var NativeEndian = binary.BigEndian // ClangEndian is set to either "el" or "eb" depending on the host's endianness. const ClangEndian = "eb" diff --git a/internal/endian_le.go b/internal/endian_le.go index fde4c55a6..9488e301b 100644 --- a/internal/endian_le.go +++ b/internal/endian_le.go @@ -6,7 +6,7 @@ import "encoding/binary" // NativeEndian is set to either binary.BigEndian or binary.LittleEndian, // depending on the host's endianness. -var NativeEndian binary.ByteOrder = binary.LittleEndian +var NativeEndian = binary.LittleEndian // ClangEndian is set to either "el" or "eb" depending on the host's endianness. 
const ClangEndian = "el" diff --git a/internal/sysenc/buffer.go b/internal/sysenc/buffer.go new file mode 100644 index 000000000..c6959d9cc --- /dev/null +++ b/internal/sysenc/buffer.go @@ -0,0 +1,77 @@ +package sysenc + +import ( + "unsafe" + + "github.com/cilium/ebpf/internal/sys" +) + +type Buffer struct { + ptr unsafe.Pointer + // Size of the buffer. syscallPointerOnly if created from UnsafeBuffer or when using + // zero-copy unmarshaling. + size int +} + +const syscallPointerOnly = -1 + +func newBuffer(buf []byte) Buffer { + if len(buf) == 0 { + return Buffer{} + } + return Buffer{unsafe.Pointer(&buf[0]), len(buf)} +} + +// UnsafeBuffer constructs a Buffer for zero-copy unmarshaling. +// +// [Pointer] is the only valid method to call on such a Buffer. +// Use [SyscallBuffer] instead if possible. +func UnsafeBuffer(ptr unsafe.Pointer) Buffer { + return Buffer{ptr, syscallPointerOnly} +} + +// SyscallOutput prepares a Buffer for a syscall to write into. +// +// The buffer may point at the underlying memory of dst, in which case [Unmarshal] +// becomes a no-op. +// +// The contents of the buffer are undefined and may be non-zero. +func SyscallOutput(dst any, size int) Buffer { + if dstBuf := unsafeBackingMemory(dst); len(dstBuf) == size { + buf := newBuffer(dstBuf) + buf.size = syscallPointerOnly + return buf + } + + return newBuffer(make([]byte, size)) +} + +// CopyTo copies the buffer into dst. +// +// Returns the number of copied bytes. +func (b Buffer) CopyTo(dst []byte) int { + return copy(dst, b.unsafeBytes()) +} + +// Pointer returns the location where a syscall should write. +func (b Buffer) Pointer() sys.Pointer { + // NB: This deliberately ignores b.length to support zero-copy + // marshaling / unmarshaling using unsafe.Pointer. + return sys.NewPointer(b.ptr) +} + +// Unmarshal the buffer into the provided value. 
+func (b Buffer) Unmarshal(data any) error { + if b.size == syscallPointerOnly { + return nil + } + + return Unmarshal(data, b.unsafeBytes()) +} + +func (b Buffer) unsafeBytes() []byte { + if b.size == syscallPointerOnly { + return nil + } + return unsafe.Slice((*byte)(b.ptr), b.size) +} diff --git a/internal/sysenc/buffer_test.go b/internal/sysenc/buffer_test.go new file mode 100644 index 000000000..d6fc64e8c --- /dev/null +++ b/internal/sysenc/buffer_test.go @@ -0,0 +1,27 @@ +package sysenc_test + +import ( + "testing" + "unsafe" + + "github.com/cilium/ebpf/internal/sys" + "github.com/cilium/ebpf/internal/sysenc" + qt "github.com/frankban/quicktest" +) + +func TestZeroBuffer(t *testing.T) { + var zero sysenc.Buffer + + qt.Assert(t, zero.CopyTo(make([]byte, 1)), qt.Equals, 0) + qt.Assert(t, zero.Pointer(), qt.Equals, sys.Pointer{}) + qt.Assert(t, zero.Unmarshal(new(uint16)), qt.IsNotNil) +} + +func TestUnsafeBuffer(t *testing.T) { + ptr := unsafe.Pointer(new(uint16)) + buf := sysenc.UnsafeBuffer(ptr) + + qt.Assert(t, buf.CopyTo(make([]byte, 1)), qt.Equals, 0) + qt.Assert(t, buf.Pointer(), qt.Equals, sys.NewPointer(ptr)) + qt.Assert(t, buf.Unmarshal(new(uint16)), qt.IsNil) +} diff --git a/internal/sysenc/doc.go b/internal/sysenc/doc.go new file mode 100644 index 000000000..676ad98ba --- /dev/null +++ b/internal/sysenc/doc.go @@ -0,0 +1,3 @@ +// Package sysenc provides efficient conversion of Go values to system +// call interfaces. +package sysenc diff --git a/internal/sysenc/layout.go b/internal/sysenc/layout.go new file mode 100644 index 000000000..52d111e7a --- /dev/null +++ b/internal/sysenc/layout.go @@ -0,0 +1,41 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found at https://go.dev/LICENSE. 
+ +package sysenc + +import ( + "reflect" + "sync" +) + +var hasUnexportedFieldsCache sync.Map // map[reflect.Type]bool + +func hasUnexportedFields(typ reflect.Type) bool { + switch typ.Kind() { + case reflect.Slice, reflect.Array, reflect.Pointer: + return hasUnexportedFields(typ.Elem()) + + case reflect.Struct: + if unexported, ok := hasUnexportedFieldsCache.Load(typ); ok { + return unexported.(bool) + } + + unexported := false + for i, n := 0, typ.NumField(); i < n; i++ { + field := typ.Field(i) + // Package binary allows _ fields but always writes zeroes into them. + if (!field.IsExported() && field.Name != "_") || hasUnexportedFields(field.Type) { + unexported = true + break + } + } + + hasUnexportedFieldsCache.Store(typ, unexported) + return unexported + + default: + // NB: It's not clear what this means for Chan and so on. + return false + } +} diff --git a/internal/sysenc/layout_test.go b/internal/sysenc/layout_test.go new file mode 100644 index 000000000..641362f98 --- /dev/null +++ b/internal/sysenc/layout_test.go @@ -0,0 +1,33 @@ +package sysenc + +import ( + "fmt" + "reflect" + "testing" + + qt "github.com/frankban/quicktest" +) + +func TestHasUnexportedFields(t *testing.T) { + for _, test := range []struct { + value any + result bool + }{ + {struct{ A any }{}, false}, + {(*struct{ A any })(nil), false}, + {([]struct{ A any })(nil), false}, + {struct{ _ any }{}, false}, + {struct{ _ struct{ a any } }{}, true}, + {(*struct{ _ any })(nil), false}, + {([]struct{ _ any })(nil), false}, + {struct{ a any }{}, true}, + {(*struct{ a any })(nil), true}, + {([]struct{ a any })(nil), true}, + {(*struct{ A []struct{ a any } })(nil), true}, + } { + t.Run(fmt.Sprintf("%T", test.value), func(t *testing.T) { + have := hasUnexportedFields(reflect.TypeOf(test.value)) + qt.Assert(t, have, qt.Equals, test.result) + }) + } +} diff --git a/internal/sysenc/marshal.go b/internal/sysenc/marshal.go new file mode 100644 index 000000000..235a1df26 --- /dev/null +++ 
b/internal/sysenc/marshal.go @@ -0,0 +1,163 @@ +package sysenc + +import ( + "bytes" + "encoding" + "encoding/binary" + "errors" + "fmt" + "reflect" + "sync" + "unsafe" + + "github.com/cilium/ebpf/internal" +) + +// Marshal turns data into a byte slice using the system's native endianness. +// +// If possible, avoids allocations by directly using the backing memory +// of data. This means that the variable must not be modified for the lifetime +// of the returned [Buffer]. +// +// Returns an error if the data can't be turned into a byte slice according to +// the behaviour of [binary.Write]. +func Marshal(data any, size int) (Buffer, error) { + if data == nil { + return Buffer{}, errors.New("can't marshal a nil value") + } + + var buf []byte + var err error + switch value := data.(type) { + case encoding.BinaryMarshaler: + buf, err = value.MarshalBinary() + case string: + buf = unsafe.Slice(unsafe.StringData(value), len(value)) + case []byte: + buf = value + case int16: + buf = internal.NativeEndian.AppendUint16(make([]byte, 0, 2), uint16(value)) + case uint16: + buf = internal.NativeEndian.AppendUint16(make([]byte, 0, 2), value) + case int32: + buf = internal.NativeEndian.AppendUint32(make([]byte, 0, 4), uint32(value)) + case uint32: + buf = internal.NativeEndian.AppendUint32(make([]byte, 0, 4), value) + case int64: + buf = internal.NativeEndian.AppendUint64(make([]byte, 0, 8), uint64(value)) + case uint64: + buf = internal.NativeEndian.AppendUint64(make([]byte, 0, 8), value) + default: + if buf := unsafeBackingMemory(data); len(buf) == size { + return newBuffer(buf), nil + } + + wr := internal.NewBuffer(make([]byte, 0, size)) + defer internal.PutBuffer(wr) + + err = binary.Write(wr, internal.NativeEndian, value) + buf = wr.Bytes() + } + if err != nil { + return Buffer{}, err + } + + if len(buf) != size { + return Buffer{}, fmt.Errorf("%T doesn't marshal to %d bytes", data, size) + } + + return newBuffer(buf), nil +} + +var bytesReaderPool = sync.Pool{ + New: 
func() interface{} { + return new(bytes.Reader) + }, +} + +// Unmarshal a byte slice in the system's native endianness into data. +// +// Returns an error if buf can't be unmarshalled according to the behaviour +// of [binary.Read]. +func Unmarshal(data interface{}, buf []byte) error { + switch value := data.(type) { + case encoding.BinaryUnmarshaler: + return value.UnmarshalBinary(buf) + + case *string: + *value = string(buf) + return nil + + default: + if dataBuf := unsafeBackingMemory(data); len(dataBuf) == len(buf) { + copy(dataBuf, buf) + return nil + } + + rd := bytesReaderPool.Get().(*bytes.Reader) + defer bytesReaderPool.Put(rd) + + rd.Reset(buf) + + return binary.Read(rd, internal.NativeEndian, value) + } +} + +// unsafeBackingMemory returns the backing memory of data if it can be used +// instead of calling into package binary. +// +// Returns nil if the value is not a pointer or a slice, or if it contains +// padding or unexported fields. +func unsafeBackingMemory(data any) []byte { + if data == nil { + return nil + } + + value := reflect.ValueOf(data) + var valueSize int + switch value.Kind() { + case reflect.Pointer: + if value.IsNil() { + return nil + } + + if elemType := value.Type().Elem(); elemType.Kind() != reflect.Slice { + valueSize = int(elemType.Size()) + break + } + + // We're dealing with a pointer to a slice. Dereference and + // handle it like a regular slice. + value = value.Elem() + fallthrough + + case reflect.Slice: + valueSize = int(value.Type().Elem().Size()) * value.Len() + + default: + // Prevent Value.UnsafePointer from panicking. + return nil + } + + // Some nil pointer types currently crash binary.Size. Call it after our own + // code so that the panic isn't reachable. + // See https://github.com/golang/go/issues/60892 + if size := binary.Size(data); size == -1 || size != valueSize { + // The type contains padding or unsupported types. 
+ return nil + } + + if hasUnexportedFields(reflect.TypeOf(data)) { + return nil + } + + // Reinterpret the pointer as a byte slice. This violates the unsafe.Pointer + // rules because it's very unlikely that the source data has "an equivalent + // memory layout". However, we can make it safe-ish because of the + // following reasons: + // - There is no alignment mismatch since we cast to a type with an + // alignment of 1. + // - There are no pointers in the source type so we don't upset the GC. + // - The length is verified at runtime. + return unsafe.Slice((*byte)(value.UnsafePointer()), valueSize) +} diff --git a/internal/sysenc/marshal_test.go b/internal/sysenc/marshal_test.go new file mode 100644 index 000000000..95f7651cd --- /dev/null +++ b/internal/sysenc/marshal_test.go @@ -0,0 +1,306 @@ +package sysenc + +import ( + "bytes" + "encoding/binary" + "fmt" + "math" + "reflect" + "testing" + + "github.com/cilium/ebpf/internal" + qt "github.com/frankban/quicktest" + "github.com/google/go-cmp/cmp/cmpopts" +) + +type testcase struct { + new func() any + zeroAllocs bool +} + +type struc struct { + A uint64 + B uint32 +} + +type explicitPad struct { + _ uint32 +} + +func testcases() []testcase { + return []testcase{ + {func() any { return new([1]uint64) }, true}, + {func() any { return new(int16) }, true}, + {func() any { return new(uint16) }, true}, + {func() any { return new(int32) }, true}, + {func() any { return new(uint32) }, true}, + {func() any { return new(int64) }, true}, + {func() any { return new(uint64) }, true}, + {func() any { return make([]byte, 9) }, true}, + {func() any { return new(explicitPad) }, true}, + {func() any { return make([]explicitPad, 0) }, false}, + {func() any { return make([]explicitPad, 1) }, false}, + {func() any { return make([]explicitPad, 2) }, false}, + {func() any { return new(struc) }, false}, + {func() any { return make([]struc, 0) }, false}, + {func() any { return make([]struc, 1) }, false}, + {func() any { return 
make([]struc, 2) }, false}, + {func() any { return int16(math.MaxInt16) }, false}, + {func() any { return uint16(math.MaxUint16) }, false}, + {func() any { return int32(math.MaxInt32) }, false}, + {func() any { return uint32(math.MaxUint32) }, false}, + {func() any { return int64(math.MaxInt64) }, false}, + {func() any { return uint64(math.MaxUint64) }, false}, + {func() any { return struc{math.MaxUint64, math.MaxUint32} }, false}, + } +} + +func TestMarshal(t *testing.T) { + for _, test := range testcases() { + value := test.new() + t.Run(fmt.Sprintf("%T", value), func(t *testing.T) { + var want bytes.Buffer + if err := binary.Write(&want, internal.NativeEndian, value); err != nil { + t.Fatal(err) + } + + have := make([]byte, want.Len()) + buf, err := Marshal(value, binary.Size(value)) + if err != nil { + t.Fatal(err) + } + qt.Assert(t, buf.CopyTo(have), qt.Equals, want.Len()) + qt.Assert(t, have, qt.CmpEquals(cmpopts.EquateEmpty()), want.Bytes()) + }) + } +} + +func TestMarshalAllocations(t *testing.T) { + allocationsPerMarshal := func(t *testing.T, data any) float64 { + size := binary.Size(data) + return testing.AllocsPerRun(5, func() { + _, err := Marshal(data, size) + if err != nil { + t.Fatal(err) + } + }) + } + + for _, test := range testcases() { + if !test.zeroAllocs { + continue + } + + value := test.new() + t.Run(fmt.Sprintf("%T", value), func(t *testing.T) { + qt.Assert(t, allocationsPerMarshal(t, value), qt.Equals, float64(0)) + }) + } +} + +func TestUnmarshal(t *testing.T) { + for _, test := range testcases() { + value := test.new() + if !canUnmarshalInto(value) { + continue + } + + t.Run(fmt.Sprintf("%T", value), func(t *testing.T) { + want := test.new() + buf := randomiseValue(t, want) + + qt.Assert(t, Unmarshal(value, buf), qt.IsNil) + qt.Assert(t, value, qt.DeepEquals, want) + }) + } +} + +func TestUnmarshalAllocations(t *testing.T) { + allocationsPerUnmarshal := func(t *testing.T, data any, buf []byte) float64 { + return testing.AllocsPerRun(5, 
func() { + err := Unmarshal(data, buf) + if err != nil { + t.Fatal(err) + } + }) + } + + for _, test := range testcases() { + if !test.zeroAllocs { + continue + } + + value := test.new() + if !canUnmarshalInto(value) { + continue + } + + t.Run(fmt.Sprintf("%T", value), func(t *testing.T) { + buf := make([]byte, binary.Size(value)) + qt.Assert(t, allocationsPerUnmarshal(t, value, buf), qt.Equals, float64(0)) + }) + } +} + +func TestUnsafeBackingMemory(t *testing.T) { + marshalNative := func(t *testing.T, data any) []byte { + t.Helper() + + var buf bytes.Buffer + qt.Assert(t, binary.Write(&buf, internal.NativeEndian, data), qt.IsNil) + return buf.Bytes() + } + + for _, test := range []struct { + name string + value any + }{ + { + "slice", + []uint32{1, 2}, + }, + { + "pointer to slice", + &[]uint32{2}, + }, + { + "pointer to array", + &[2]uint64{}, + }, + { + "pointer to int64", + new(int64), + }, + { + "pointer to struct", + &struct { + A, B uint16 + C uint32 + }{}, + }, + { + "struct with explicit padding", + &struct{ _ uint64 }{}, + }, + } { + t.Run("valid: "+test.name, func(t *testing.T) { + want := marshalNative(t, test.value) + have := unsafeBackingMemory(test.value) + qt.Assert(t, have, qt.DeepEquals, want) + }) + } + + for _, test := range []struct { + name string + value any + }{ + { + "nil", + nil, + }, + { + "nil slice", + ([]byte)(nil), + }, + { + "nil pointer", + (*uint64)(nil), + }, + { + "nil pointer to slice", + (*[]uint32)(nil), + }, + { + "nil pointer to array", + (*[2]uint64)(nil), + }, + { + "unexported field", + &struct{ a uint64 }{}, + }, + { + "struct containing pointer", + &struct{ A *uint64 }{}, + }, + { + "struct with trailing padding", + &struc{}, + }, + { + "struct with interspersed padding", + &struct { + B uint32 + A uint64 + }{}, + }, + { + "padding between slice entries", + &[]struc{{}}, + }, + { + "padding between array entries", + &[2]struc{}, + }, + } { + t.Run("invalid: "+test.name, func(t *testing.T) { + qt.Assert(t, 
unsafeBackingMemory(test.value), qt.IsNil) + }) + } +} + +func BenchmarkMarshal(b *testing.B) { + for _, test := range testcases() { + value := test.new() + b.Run(fmt.Sprintf("%T", value), func(b *testing.B) { + size := binary.Size(value) + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _, _ = Marshal(value, size) + } + }) + } +} + +func BenchmarkUnmarshal(b *testing.B) { + for _, test := range testcases() { + value := test.new() + if !canUnmarshalInto(value) { + continue + } + + b.Run(fmt.Sprintf("%T", value), func(b *testing.B) { + size := binary.Size(value) + buf := make([]byte, size) + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _ = Unmarshal(value, buf) + } + }) + } +} + +func randomiseValue(tb testing.TB, value any) []byte { + tb.Helper() + + size := binary.Size(value) + if size == -1 { + tb.Fatalf("Can't unmarshal into %T", value) + } + + buf := make([]byte, size) + for i := range buf { + buf[i] = byte(i) + } + + err := binary.Read(bytes.NewReader(buf), internal.NativeEndian, value) + qt.Assert(tb, err, qt.IsNil) + + return buf +} + +func canUnmarshalInto(data any) bool { + kind := reflect.TypeOf(data).Kind() + return kind == reflect.Slice || kind == reflect.Pointer +} diff --git a/map.go b/map.go index b2c280639..be732a24f 100644 --- a/map.go +++ b/map.go @@ -15,6 +15,7 @@ import ( "github.com/cilium/ebpf/btf" "github.com/cilium/ebpf/internal" "github.com/cilium/ebpf/internal/sys" + "github.com/cilium/ebpf/internal/sysenc" "github.com/cilium/ebpf/internal/unix" ) @@ -584,8 +585,8 @@ func (m *Map) LookupWithFlags(key, valueOut interface{}, flags MapLookupFlags) e return m.lookupPerCPU(key, valueOut, flags) } - valuePtr, valueBytes := makeBuffer(valueOut, m.fullValueSize) - if err := m.lookup(key, valuePtr, flags); err != nil { + valueBytes := makeMapSyscallOutput(valueOut, m.fullValueSize) + if err := m.lookup(key, valueBytes.Pointer(), flags); err != nil { return err } @@ -611,8 +612,8 @@ func (m *Map) 
LookupAndDeleteWithFlags(key, valueOut interface{}, flags MapLooku return m.lookupAndDeletePerCPU(key, valueOut, flags) } - valuePtr, valueBytes := makeBuffer(valueOut, m.fullValueSize) - if err := m.lookupAndDelete(key, valuePtr, flags); err != nil { + valueBytes := makeMapSyscallOutput(valueOut, m.fullValueSize) + if err := m.lookupAndDelete(key, valueBytes.Pointer(), flags); err != nil { return err } return m.unmarshalValue(valueOut, valueBytes) @@ -780,13 +781,13 @@ func (m *Map) Delete(key interface{}) error { // // Returns ErrKeyNotExist if there is no next key. func (m *Map) NextKey(key, nextKeyOut interface{}) error { - nextKeyPtr, nextKeyBytes := makeBuffer(nextKeyOut, int(m.keySize)) + nextKeyBytes := makeMapSyscallOutput(nextKeyOut, int(m.keySize)) - if err := m.nextKey(key, nextKeyPtr); err != nil { + if err := m.nextKey(key, nextKeyBytes.Pointer()); err != nil { return err } - if err := m.unmarshalKey(nextKeyOut, nextKeyBytes); err != nil { + if err := nextKeyBytes.Unmarshal(nextKeyOut); err != nil { return fmt.Errorf("can't unmarshal next key: %w", err) } return nil @@ -957,14 +958,14 @@ func (m *Map) batchLookup(cmd sys.Cmd, startKey, nextKeyOut, keysOut, valuesOut keyPtr := sys.NewSlicePointer(keyBuf) valueBuf := make([]byte, count*int(m.fullValueSize)) valuePtr := sys.NewSlicePointer(valueBuf) - nextPtr, nextBuf := makeBuffer(nextKeyOut, int(m.keySize)) + nextBuf := makeMapSyscallOutput(nextKeyOut, int(m.keySize)) attr := sys.MapLookupBatchAttr{ MapFd: m.fd.Uint(), Keys: keyPtr, Values: valuePtr, Count: uint32(count), - OutBatch: nextPtr, + OutBatch: nextBuf.Pointer(), } if opts != nil { @@ -974,7 +975,7 @@ func (m *Map) batchLookup(cmd sys.Cmd, startKey, nextKeyOut, keysOut, valuesOut var err error if startKey != nil { - attr.InBatch, err = marshalPtr(startKey, int(m.keySize)) + attr.InBatch, err = marshalMapSyscallInput(startKey, int(m.keySize)) if err != nil { return 0, err } @@ -986,15 +987,15 @@ func (m *Map) batchLookup(cmd sys.Cmd, startKey, 
nextKeyOut, keysOut, valuesOut return 0, sysErr } - err = m.unmarshalKey(nextKeyOut, nextBuf) + err = nextBuf.Unmarshal(nextKeyOut) if err != nil { return 0, err } - err = unmarshalBytes(keysOut, keyBuf) + err = sysenc.Unmarshal(keysOut, keyBuf) if err != nil { return 0, err } - err = unmarshalBytes(valuesOut, valueBuf) + err = sysenc.Unmarshal(valuesOut, valueBuf) if err != nil { return 0, err } @@ -1026,11 +1027,11 @@ func (m *Map) BatchUpdate(keys, values interface{}, opts *BatchOptions) (int, er if count != valuesValue.Len() { return 0, fmt.Errorf("keys and values must be the same length") } - keyPtr, err := marshalPtr(keys, count*int(m.keySize)) + keyPtr, err := marshalMapSyscallInput(keys, count*int(m.keySize)) if err != nil { return 0, err } - valuePtr, err = marshalPtr(values, count*int(m.valueSize)) + valuePtr, err = marshalMapSyscallInput(values, count*int(m.valueSize)) if err != nil { return 0, err } @@ -1068,7 +1069,7 @@ func (m *Map) BatchDelete(keys interface{}, opts *BatchOptions) (int, error) { return 0, fmt.Errorf("keys must be a slice") } count := keysValue.Len() - keyPtr, err := marshalPtr(keys, count*int(m.keySize)) + keyPtr, err := marshalMapSyscallInput(keys, count*int(m.keySize)) if err != nil { return 0, fmt.Errorf("cannot marshal keys: %v", err) } @@ -1232,16 +1233,7 @@ func (m *Map) marshalKey(data interface{}) (sys.Pointer, error) { return sys.Pointer{}, errors.New("can't use nil as key of map") } - return marshalPtr(data, int(m.keySize)) -} - -func (m *Map) unmarshalKey(data interface{}, buf []byte) error { - if buf == nil { - // This is from a makeBuffer call, nothing do do here. 
- return nil - } - - return unmarshalBytes(data, buf) + return marshalMapSyscallInput(data, int(m.keySize)) } func (m *Map) marshalValue(data interface{}) (sys.Pointer, error) { @@ -1264,7 +1256,7 @@ func (m *Map) marshalValue(data interface{}) (sys.Pointer, error) { buf, err = marshalProgram(value, int(m.valueSize)) default: - return marshalPtr(data, int(m.valueSize)) + return marshalMapSyscallInput(data, int(m.valueSize)) } if err != nil { @@ -1274,16 +1266,7 @@ func (m *Map) marshalValue(data interface{}) (sys.Pointer, error) { return sys.NewSlicePointer(buf), nil } -func (m *Map) unmarshalValue(value interface{}, buf []byte) error { - if buf == nil { - // This is from a makeBuffer call, nothing do do here. - return nil - } - - if m.typ.hasPerCPUValue() { - return unmarshalPerCPUValue(value, int(m.valueSize), buf) - } - +func (m *Map) unmarshalValue(value any, buf sysenc.Buffer) error { switch value := value.(type) { case **Map: if !m.typ.canStoreMap() { @@ -1330,7 +1313,7 @@ func (m *Map) unmarshalValue(value interface{}, buf []byte) error { return errors.New("require pointer to *Program") } - return unmarshalBytes(value, buf) + return buf.Unmarshal(value) } // LoadPinnedMap loads a Map from a BPF file. @@ -1352,12 +1335,11 @@ func LoadPinnedMap(fileName string, opts *LoadPinOptions) (*Map, error) { } // unmarshalMap creates a map from a map ID encoded in host endianness. 
-func unmarshalMap(buf []byte) (*Map, error) { - if len(buf) != 4 { - return nil, errors.New("map id requires 4 byte value") +func unmarshalMap(buf sysenc.Buffer) (*Map, error) { + var id uint32 + if err := buf.Unmarshal(&id); err != nil { + return nil, err } - - id := internal.NativeEndian.Uint32(buf) return NewMapFromID(MapID(id)) } @@ -1453,7 +1435,12 @@ func (mi *MapIterator) Next(keyOut, valueOut interface{}) bool { return false } - mi.err = mi.target.unmarshalKey(keyOut, nextKey) + if ptr, ok := keyOut.(unsafe.Pointer); ok { + copy(unsafe.Slice((*byte)(ptr), len(nextKey)), nextKey) + } else { + mi.err = sysenc.Unmarshal(keyOut, nextKey) + } + return mi.err == nil } diff --git a/marshalers.go b/marshalers.go index a568bff92..e89a12f0f 100644 --- a/marshalers.go +++ b/marshalers.go @@ -1,166 +1,53 @@ package ebpf import ( - "bytes" "encoding" - "encoding/binary" "errors" "fmt" "reflect" - "runtime" - "sync" "unsafe" "github.com/cilium/ebpf/internal" "github.com/cilium/ebpf/internal/sys" + "github.com/cilium/ebpf/internal/sysenc" ) -// marshalPtr converts an arbitrary value into a pointer suitable +// marshalMapSyscallInput converts an arbitrary value into a pointer suitable // to be passed to the kernel. // // As an optimization, it returns the original value if it is an // unsafe.Pointer. -func marshalPtr(data interface{}, length int) (sys.Pointer, error) { +func marshalMapSyscallInput(data any, length int) (sys.Pointer, error) { if ptr, ok := data.(unsafe.Pointer); ok { return sys.NewPointer(ptr), nil } - buf, err := marshalBytes(data, length) + buf, err := sysenc.Marshal(data, length) if err != nil { return sys.Pointer{}, err } - return sys.NewSlicePointer(buf), nil + return buf.Pointer(), nil } -// marshalBytes converts an arbitrary value into a byte buffer. -// -// Prefer using Map.marshalKey and Map.marshalValue if possible, since -// those have special cases that allow more types to be encoded. 
-// -// Returns an error if the given value isn't representable in exactly -// length bytes. -func marshalBytes(data interface{}, length int) (buf []byte, err error) { - if data == nil { - return nil, errors.New("can't marshal a nil value") - } - - switch value := data.(type) { - case encoding.BinaryMarshaler: - buf, err = value.MarshalBinary() - case string: - buf = []byte(value) - case []byte: - buf = value - case unsafe.Pointer: - err = errors.New("can't marshal from unsafe.Pointer") - case Map, *Map, Program, *Program: - err = fmt.Errorf("can't marshal %T", value) - default: - wr := internal.NewBuffer(make([]byte, 0, length)) - defer internal.PutBuffer(wr) - - err = binary.Write(wr, internal.NativeEndian, value) - if err != nil { - err = fmt.Errorf("encoding %T: %v", value, err) - } - buf = wr.Bytes() - } - if err != nil { - return nil, err - } - - if len(buf) != length { - return nil, fmt.Errorf("%T doesn't marshal to %d bytes", data, length) - } - return buf, nil -} - -func makeBuffer(dst interface{}, length int) (sys.Pointer, []byte) { +func makeMapSyscallOutput(dst any, length int) sysenc.Buffer { if ptr, ok := dst.(unsafe.Pointer); ok { - return sys.NewPointer(ptr), nil + return sysenc.UnsafeBuffer(ptr) } - buf := make([]byte, length) - return sys.NewSlicePointer(buf), buf -} - -var bytesReaderPool = sync.Pool{ - New: func() interface{} { - return new(bytes.Reader) - }, -} - -// unmarshalBytes converts a byte buffer into an arbitrary value. -// -// Prefer using Map.unmarshalKey and Map.unmarshalValue if possible, since -// those have special cases that allow more types to be encoded. -// -// The common int32 and int64 types are directly handled to avoid -// unnecessary heap allocations as happening in the default case. 
-func unmarshalBytes(data interface{}, buf []byte) error { - switch value := data.(type) { - case unsafe.Pointer: - dst := unsafe.Slice((*byte)(value), len(buf)) - copy(dst, buf) - runtime.KeepAlive(value) - return nil - case Map, *Map, Program, *Program: - return fmt.Errorf("can't unmarshal into %T", value) - case encoding.BinaryUnmarshaler: - return value.UnmarshalBinary(buf) - case *string: - *value = string(buf) - return nil - case *[]byte: - *value = buf - return nil - case *int32: - if len(buf) < 4 { - return errors.New("int32 requires 4 bytes") - } - *value = int32(internal.NativeEndian.Uint32(buf)) - return nil - case *uint32: - if len(buf) < 4 { - return errors.New("uint32 requires 4 bytes") - } - *value = internal.NativeEndian.Uint32(buf) - return nil - case *int64: - if len(buf) < 8 { - return errors.New("int64 requires 8 bytes") - } - *value = int64(internal.NativeEndian.Uint64(buf)) - return nil - case *uint64: - if len(buf) < 8 { - return errors.New("uint64 requires 8 bytes") - } - *value = internal.NativeEndian.Uint64(buf) - return nil - case string: - return errors.New("require pointer to string") - case []byte: - return errors.New("require pointer to []byte") - default: - rd := bytesReaderPool.Get().(*bytes.Reader) - rd.Reset(buf) - defer bytesReaderPool.Put(rd) - if err := binary.Read(rd, internal.NativeEndian, value); err != nil { - return fmt.Errorf("decoding %T: %v", value, err) - } - return nil + _, ok := dst.(encoding.BinaryUnmarshaler) + if ok { + return sysenc.SyscallOutput(nil, length) } + + return sysenc.SyscallOutput(dst, length) } // marshalPerCPUValue encodes a slice containing one value per // possible CPU into a buffer of bytes. // // Values are initialized to zero if the slice has less elements than CPUs. -// -// slice must have a type like []elementType. 
-func marshalPerCPUValue(slice interface{}, elemLength int) (sys.Pointer, error) { +func marshalPerCPUValue(slice any, elemLength int) (sys.Pointer, error) { sliceType := reflect.TypeOf(slice) if sliceType.Kind() != reflect.Slice { return sys.Pointer{}, errors.New("per-CPU value requires slice") @@ -182,13 +69,13 @@ func marshalPerCPUValue(slice interface{}, elemLength int) (sys.Pointer, error) for i := 0; i < sliceLen; i++ { elem := sliceValue.Index(i).Interface() - elemBytes, err := marshalBytes(elem, elemLength) + elemBytes, err := sysenc.Marshal(elem, elemLength) if err != nil { return sys.Pointer{}, err } offset := i * alignedElemLength - copy(buf[offset:offset+elemLength], elemBytes) + elemBytes.CopyTo(buf[offset : offset+elemLength]) } return sys.NewSlicePointer(buf), nil @@ -197,8 +84,8 @@ func marshalPerCPUValue(slice interface{}, elemLength int) (sys.Pointer, error) // unmarshalPerCPUValue decodes a buffer into a slice containing one value per // possible CPU. // -// valueOut must have a type like *[]elementType -func unmarshalPerCPUValue(slicePtr interface{}, elemLength int, buf []byte) error { +// slicePtr must be a pointer to a slice. 
+func unmarshalPerCPUValue(slicePtr any, elemLength int, buf []byte) error { slicePtrType := reflect.TypeOf(slicePtr) if slicePtrType.Kind() != reflect.Ptr || slicePtrType.Elem().Kind() != reflect.Slice { return fmt.Errorf("per-cpu value requires pointer to slice") @@ -218,12 +105,9 @@ func unmarshalPerCPUValue(slicePtr interface{}, elemLength int, buf []byte) erro sliceElemType = sliceElemType.Elem() } - step := len(buf) / possibleCPUs - if step < elemLength { - return fmt.Errorf("per-cpu element length is larger than available data") - } + stride := internal.Align(elemLength, 8) for i := 0; i < possibleCPUs; i++ { - var elem interface{} + var elem any if sliceElemIsPointer { newElem := reflect.New(sliceElemType) slice.Index(i).Set(newElem) @@ -232,16 +116,12 @@ func unmarshalPerCPUValue(slicePtr interface{}, elemLength int, buf []byte) erro elem = slice.Index(i).Addr().Interface() } - // Make a copy, since unmarshal can hold on to itemBytes - elemBytes := make([]byte, elemLength) - copy(elemBytes, buf[:elemLength]) - - err := unmarshalBytes(elem, elemBytes) + err := sysenc.Unmarshal(elem, buf[:elemLength]) if err != nil { return fmt.Errorf("cpu %d: %w", i, err) } - buf = buf[step:] + buf = buf[stride:] } reflect.ValueOf(slicePtr).Elem().Set(slice) diff --git a/prog.go b/prog.go index 70aaef553..53d45bebe 100644 --- a/prog.go +++ b/prog.go @@ -16,6 +16,7 @@ import ( "github.com/cilium/ebpf/btf" "github.com/cilium/ebpf/internal" "github.com/cilium/ebpf/internal/sys" + "github.com/cilium/ebpf/internal/sysenc" "github.com/cilium/ebpf/internal/unix" ) @@ -763,14 +764,14 @@ retry: return attr.Retval, total, nil } -func unmarshalProgram(buf []byte) (*Program, error) { - if len(buf) != 4 { - return nil, errors.New("program id requires 4 byte value") +func unmarshalProgram(buf sysenc.Buffer) (*Program, error) { + var id uint32 + if err := buf.Unmarshal(&id); err != nil { + return nil, err } // Looking up an entry in a nested map or prog array returns an id, // not an fd. 
- id := internal.NativeEndian.Uint32(buf) return NewProgramFromID(ProgramID(id)) } diff --git a/prog_test.go b/prog_test.go index d4da0937a..78a75238c 100644 --- a/prog_test.go +++ b/prog_test.go @@ -701,7 +701,7 @@ func TestProgramRejectIncorrectByteOrder(t *testing.T) { spec := socketFilterSpec.Copy() spec.ByteOrder = binary.BigEndian - if internal.NativeEndian == binary.BigEndian { + if spec.ByteOrder == internal.NativeEndian { spec.ByteOrder = binary.LittleEndian } diff --git a/syscalls.go b/syscalls.go index e3b3f3c70..cdf1fcf2e 100644 --- a/syscalls.go +++ b/syscalls.go @@ -225,8 +225,8 @@ var haveBatchAPI = internal.NewFeatureTest("map batch api", "5.6", func() error keys := []uint32{1, 2} values := []uint32{3, 4} - kp, _ := marshalPtr(keys, 8) - vp, _ := marshalPtr(values, 8) + kp, _ := marshalMapSyscallInput(keys, 8) + vp, _ := marshalMapSyscallInput(values, 8) err = sys.MapUpdateBatch(&sys.MapUpdateBatchAttr{ MapFd: fd.Uint(),