diff --git a/btf/btf.go b/btf/btf.go index a27dcd16a..1ff2ce7e4 100644 --- a/btf/btf.go +++ b/btf/btf.go @@ -2,7 +2,6 @@ package btf import ( "bufio" - "bytes" "debug/elf" "encoding/binary" "errors" @@ -31,11 +30,9 @@ var ( // ID represents the unique ID of a BTF object. type ID = sys.BTFID -// Spec represents decoded BTF. +// Spec allows querying a set of Types and loading the set into the +// kernel. type Spec struct { - // Data from .BTF. - strings *stringTable - // All types contained by the spec, not including types from the base in // case the spec was parsed from split BTF. types []Type @@ -43,10 +40,17 @@ type Spec struct { // Type IDs indexed by type. typeIDs map[Type]TypeID + // The last allocated type ID. + lastTypeID TypeID + // Types indexed by essential name. // Includes all struct flavors and types with the same name. namedTypes map[essentialName][]Type + // String table from ELF, may be nil. + strings *stringTable + + // Byte order of the ELF we decoded the spec from, may be nil. byteOrder binary.ByteOrder } @@ -76,6 +80,18 @@ func (h *btfHeader) stringStart() int64 { return int64(h.HdrLen + h.StringOff) } +// NewSpec creates a Spec containing only Void. +func NewSpec() *Spec { + return &Spec{ + []Type{(*Void)(nil)}, + map[Type]TypeID{(*Void)(nil): 0}, + 0, + make(map[essentialName][]Type), + nil, + nil, + } +} + // LoadSpec opens file and calls LoadSpecFromReader on it. 
func LoadSpec(file string) (*Spec, error) { fh, err := os.Open(file) @@ -214,7 +230,6 @@ func loadSpecFromELF(file *internal.SafeELFFile) (*Spec, error) { func loadRawSpec(btf io.ReaderAt, bo binary.ByteOrder, baseTypes types, baseStrings *stringTable) (*Spec, error) { - rawTypes, rawStrings, err := parseBTF(btf, bo, baseStrings) if err != nil { return nil, err @@ -225,18 +240,19 @@ func loadRawSpec(btf io.ReaderAt, bo binary.ByteOrder, return nil, err } - typeIDs, typesByName := indexTypes(types, TypeID(len(baseTypes))) + typeIDs, typesByName, lastTypeID := indexTypes(types, TypeID(len(baseTypes))) return &Spec{ namedTypes: typesByName, typeIDs: typeIDs, types: types, + lastTypeID: lastTypeID, strings: rawStrings, byteOrder: bo, }, nil } -func indexTypes(types []Type, typeIDOffset TypeID) (map[Type]TypeID, map[essentialName][]Type) { +func indexTypes(types []Type, typeIDOffset TypeID) (map[Type]TypeID, map[essentialName][]Type, TypeID) { namedTypes := 0 for _, typ := range types { if typ.TypeName() != "" { @@ -250,14 +266,16 @@ func indexTypes(types []Type, typeIDOffset TypeID) (map[Type]TypeID, map[essenti typeIDs := make(map[Type]TypeID, len(types)) typesByName := make(map[essentialName][]Type, namedTypes) + var lastTypeID TypeID for i, typ := range types { if name := newEssentialName(typ.TypeName()); name != "" { typesByName[name] = append(typesByName[name], typ) } - typeIDs[typ] = TypeID(i) + typeIDOffset + lastTypeID = TypeID(i) + typeIDOffset + typeIDs[typ] = lastTypeID } - return typeIDs, typesByName + return typeIDs, typesByName, lastTypeID } // LoadKernelSpec returns the current kernel's BTF information. @@ -469,15 +487,15 @@ func fixupDatasec(types []Type, sectionSizes map[string]uint32, offsets map[symb // Copy creates a copy of Spec. 
func (s *Spec) Copy() *Spec { types := copyTypes(s.types, nil) - - typeIDs, typesByName := indexTypes(types, s.firstTypeID()) + typeIDs, typesByName, lastTypeID := indexTypes(types, s.firstTypeID()) // NB: Other parts of spec are not copied since they are immutable. return &Spec{ - s.strings, types, typeIDs, + lastTypeID, typesByName, + s.strings, s.byteOrder, } } @@ -492,6 +510,36 @@ func (sw sliceWriter) Write(p []byte) (int, error) { return copy(sw, p), nil } +// Add a Type to the Spec, making it queryable via [TypeByName], etc. +// +// Adding the identical Type multiple times is valid and will return a stable ID. +// +// See [Type] for details on identity. +func (s *Spec) Add(typ Type) (TypeID, error) { + if typ == nil { + return 0, fmt.Errorf("can't add nil Type") + } + + if id, err := s.TypeID(typ); err == nil { + return id, nil + } + + id := s.lastTypeID + 1 + if id < s.lastTypeID { + return 0, fmt.Errorf("type ID overflow") + } + + s.typeIDs[typ] = id + s.types = append(s.types, typ) + s.lastTypeID = id + + if name := newEssentialName(typ.TypeName()); name != "" { + s.namedTypes[name] = append(s.namedTypes[name], typ) + } + + return id, nil +} + // TypeByID returns the BTF Type with the given type ID. // // Returns an error wrapping ErrNotFound if a Type with the given ID @@ -638,6 +686,10 @@ func (s *Spec) firstTypeID() TypeID { // Types from base are used to resolve references in the split BTF. // The returned Spec only contains types from the split BTF, not from the base. 
func LoadSplitSpecFromReader(r io.ReaderAt, base *Spec) (*Spec, error) { + if base.strings == nil { + return nil, fmt.Errorf("parse split BTF: base must be loaded from an ELF") + } + return loadRawSpec(r, internal.NativeEndian, base.types, base.strings) } @@ -665,55 +717,15 @@ func (iter *TypesIterator) Next() bool { return true } -func marshalBTF(types interface{}, strings []byte, bo binary.ByteOrder) []byte { - const minHeaderLength = 24 - - typesLen := uint32(binary.Size(types)) - header := btfHeader{ - Magic: btfMagic, - Version: 1, - HdrLen: minHeaderLength, - TypeOff: 0, - TypeLen: typesLen, - StringOff: typesLen, - StringLen: uint32(len(strings)), - } - - buf := new(bytes.Buffer) - _ = binary.Write(buf, bo, &header) - _ = binary.Write(buf, bo, types) - buf.Write(strings) - - return buf.Bytes() -} - // haveBTF attempts to load a BTF blob containing an Int. It should pass on any // kernel that supports BPF_BTF_LOAD. var haveBTF = internal.NewFeatureTest("BTF", "4.18", func() error { - var ( - types struct { - Integer btfType - btfInt - } - strings = []byte{0} - ) - types.Integer.SetKind(kindInt) // 0-length anonymous integer - - btf := marshalBTF(&types, strings, internal.NativeEndian) - - fd, err := sys.BtfLoad(&sys.BtfLoadAttr{ - Btf: sys.NewSlicePointer(btf), - BtfSize: uint32(len(btf)), - }) + // 0-length anonymous integer + err := probeBTF(&Int{}) if errors.Is(err, unix.EINVAL) || errors.Is(err, unix.EPERM) { return internal.ErrNotSupported } - if err != nil { - return err - } - - fd.Close() - return nil + return err }) // haveMapBTF attempts to load a minimal BTF blob containing a Var. 
It is @@ -724,37 +736,18 @@ var haveMapBTF = internal.NewFeatureTest("Map BTF (Var/Datasec)", "5.2", func() return err } - var ( - types struct { - Integer btfType - Var btfType - btfVariable - } - strings = []byte{0, 'a', 0} - ) - - types.Integer.SetKind(kindPointer) - types.Var.NameOff = 1 - types.Var.SetKind(kindVar) - types.Var.SizeType = 1 - - btf := marshalBTF(&types, strings, internal.NativeEndian) + v := &Var{ + Name: "a", + Type: &Pointer{(*Void)(nil)}, + } - fd, err := sys.BtfLoad(&sys.BtfLoadAttr{ - Btf: sys.NewSlicePointer(btf), - BtfSize: uint32(len(btf)), - }) + err := probeBTF(v) if errors.Is(err, unix.EINVAL) || errors.Is(err, unix.EPERM) { // Treat both EINVAL and EPERM as not supported: creating the map may still // succeed without Btf* attrs. return internal.ErrNotSupported } - if err != nil { - return err - } - - fd.Close() - return nil + return err }) // haveProgBTF attempts to load a BTF blob containing a Func and FuncProto. It @@ -765,34 +758,16 @@ var haveProgBTF = internal.NewFeatureTest("Program BTF (func/line_info)", "5.0", return err } - var ( - types struct { - FuncProto btfType - Func btfType - } - strings = []byte{0, 'a', 0} - ) - - types.FuncProto.SetKind(kindFuncProto) - types.Func.SetKind(kindFunc) - types.Func.SizeType = 1 // aka FuncProto - types.Func.NameOff = 1 - - btf := marshalBTF(&types, strings, internal.NativeEndian) + fn := &Func{ + Name: "a", + Type: &FuncProto{Return: (*Void)(nil)}, + } - fd, err := sys.BtfLoad(&sys.BtfLoadAttr{ - Btf: sys.NewSlicePointer(btf), - BtfSize: uint32(len(btf)), - }) + err := probeBTF(fn) if errors.Is(err, unix.EINVAL) || errors.Is(err, unix.EPERM) { return internal.ErrNotSupported } - if err != nil { - return err - } - - fd.Close() - return nil + return err }) var haveFuncLinkage = internal.NewFeatureTest("BTF func linkage", "5.6", func() error { @@ -800,33 +775,35 @@ var haveFuncLinkage = internal.NewFeatureTest("BTF func linkage", "5.6", func() return err } - var ( - types struct { - 
FuncProto btfType - Func btfType - } - strings = []byte{0, 'a', 0} - ) - - types.FuncProto.SetKind(kindFuncProto) - types.Func.SetKind(kindFunc) - types.Func.SizeType = 1 // aka FuncProto - types.Func.NameOff = 1 - types.Func.SetLinkage(GlobalFunc) - - btf := marshalBTF(&types, strings, internal.NativeEndian) + fn := &Func{ + Name: "a", + Type: &FuncProto{Return: (*Void)(nil)}, + Linkage: GlobalFunc, + } - fd, err := sys.BtfLoad(&sys.BtfLoadAttr{ - Btf: sys.NewSlicePointer(btf), - BtfSize: uint32(len(btf)), - }) + err := probeBTF(fn) if errors.Is(err, unix.EINVAL) { return internal.ErrNotSupported } - if err != nil { + return err +}) + +func probeBTF(typ Type) error { + buf := getBuffer() + defer putBuffer(buf) + + if err := marshalTypes(buf, []Type{&Void{}, typ}, nil, nil); err != nil { return err } - fd.Close() - return nil -}) + fd, err := sys.BtfLoad(&sys.BtfLoadAttr{ + Btf: sys.NewSlicePointer(buf.Bytes()), + BtfSize: uint32(buf.Len()), + }) + + if err == nil { + fd.Close() + } + + return err +} diff --git a/btf/btf_test.go b/btf/btf_test.go index d681a16d7..256cbcd8f 100644 --- a/btf/btf_test.go +++ b/btf/btf_test.go @@ -218,6 +218,46 @@ func TestTypeByName(t *testing.T) { } } +func TestSpecAdd(t *testing.T) { + i := &Int{ + Name: "foo", + Size: 2, + Encoding: Signed | Char, + } + pi := &Pointer{i} + + s := NewSpec() + id, err := s.Add(pi) + qt.Assert(t, err, qt.IsNil) + qt.Assert(t, id, qt.Equals, TypeID(1), qt.Commentf("First non-void type doesn't get id 1")) + + id, err = s.Add(pi) + qt.Assert(t, err, qt.IsNil) + qt.Assert(t, id, qt.Equals, TypeID(1)) + + _, err = s.TypeID(i) + qt.Assert(t, err, qt.IsNotNil, qt.Commentf("Children mustn't be added")) + + id, err = s.Add(i) + qt.Assert(t, err, qt.IsNil) + qt.Assert(t, id, qt.Equals, TypeID(2), qt.Commentf("Second type doesn't get id 2")) + + id, err = s.Add(i) + qt.Assert(t, err, qt.IsNil) + qt.Assert(t, id, qt.Equals, TypeID(2), qt.Commentf("Adding a type twice returns different ids")) + + typ, err := 
s.AnyTypeByName("foo") + qt.Assert(t, err, qt.IsNil, qt.Commentf("Add doesn't make named type queryable")) + qt.Assert(t, typ, qt.Equals, i) + + id, err = s.Add(&Typedef{"baz", i}) + qt.Assert(t, err, qt.IsNil) + qt.Assert(t, id, qt.Equals, TypeID(3)) + + _, err = s.AnyTypeByName("baz") + qt.Assert(t, err, qt.IsNil) +} + func BenchmarkParseVmlinux(b *testing.B) { rd := vmlinuxTestdataReader(b) b.ReportAllocs() @@ -329,8 +369,12 @@ func TestLoadSpecFromElf(t *testing.T) { } func TestVerifierError(t *testing.T) { - btf, _ := newEncoder(kernelEncoderOptions, nil).Encode() - _, err := newHandleFromRawBTF(btf) + var buf bytes.Buffer + if err := marshalTypes(&buf, []Type{&Void{}}, nil, nil); err != nil { + t.Fatal(err) + } + + _, err := newHandleFromRawBTF(buf.Bytes()) testutils.SkipIfNotSupported(t, err) var ve *internal.VerifierError if !errors.As(err, &ve) { diff --git a/btf/btf_types.go b/btf/btf_types.go index dc568a90e..a253b7c9b 100644 --- a/btf/btf_types.go +++ b/btf/btf_types.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "fmt" "io" + "unsafe" ) //go:generate stringer -linecomment -output=btf_types_string.go -type=FuncLinkage,VarLinkage,btfKind @@ -193,13 +194,22 @@ func (bt *btfType) SetSize(size uint32) { bt.SizeType = size } +func (bt *btfType) Marshal(w io.Writer, bo binary.ByteOrder) error { + buf := make([]byte, unsafe.Sizeof(*bt)) + bo.PutUint32(buf[0:], bt.NameOff) + bo.PutUint32(buf[4:], bt.Info) + bo.PutUint32(buf[8:], bt.SizeType) + _, err := w.Write(buf) + return err +} + type rawType struct { btfType data interface{} } func (rt *rawType) Marshal(w io.Writer, bo binary.ByteOrder) error { - if err := binary.Write(w, bo, &rt.btfType); err != nil { + if err := rt.btfType.Marshal(w, bo); err != nil { return err } diff --git a/btf/ext_info.go b/btf/ext_info.go index 36f3b7baf..dba36d6be 100644 --- a/btf/ext_info.go +++ b/btf/ext_info.go @@ -8,7 +8,6 @@ import ( "io" "math" "sort" - "sync" "github.com/cilium/ebpf/asm" "github.com/cilium/ebpf/internal" @@ 
-131,12 +130,6 @@ func (ei *ExtInfos) Assign(insns asm.Instructions, section string) { } } -var nativeEncoderPool = sync.Pool{ - New: func() any { - return newEncoder(kernelEncoderOptions, nil) - }, -} - // MarshalExtInfos encodes function and line info embedded in insns into kernel // wire format. // @@ -157,15 +150,11 @@ func MarshalExtInfos(insns asm.Instructions) (_ *Handle, funcInfos, lineInfos [] } } - // Avoid allocating encoder, etc. if there is no BTF at all. return nil, nil, nil, nil marshal: - enc := nativeEncoderPool.Get().(*encoder) - defer nativeEncoderPool.Put(enc) - - enc.Reset() - + stb := newStringTableBuilder(0) + spec := NewSpec() var fiBuf, liBuf bytes.Buffer for { if fn := FuncMetadata(iter.Ins); fn != nil { @@ -173,7 +162,7 @@ marshal: fn: fn, offset: iter.Offset, } - if err := fi.marshal(&fiBuf, enc); err != nil { + if err := fi.marshal(&fiBuf, spec); err != nil { return nil, nil, nil, fmt.Errorf("write func info: %w", err) } } @@ -183,7 +172,7 @@ marshal: line: line, offset: iter.Offset, } - if err := li.marshal(&liBuf, enc.strings); err != nil { + if err := li.marshal(&liBuf, stb); err != nil { return nil, nil, nil, fmt.Errorf("write line info: %w", err) } } @@ -193,12 +182,14 @@ marshal: } } - btf, err := enc.Encode() - if err != nil { - return nil, nil, nil, err + buf := getBuffer() + defer putBuffer(buf) + + if err := marshalTypes(buf, spec.types, stb, kernelMarshalOptions); err != nil { + return nil, nil, nil, fmt.Errorf("marshal BTF: %w", err) } - handle, err := newHandleFromRawBTF(btf) + handle, err := newHandleFromRawBTF(buf.Bytes()) return handle, fiBuf.Bytes(), liBuf.Bytes(), err } @@ -392,8 +383,8 @@ func newFuncInfos(bfis []bpfFuncInfo, ts types) ([]funcInfo, error) { } // marshal into the BTF wire format. 
-func (fi *funcInfo) marshal(w *bytes.Buffer, enc *encoder) error { - id, err := enc.Add(fi.fn) +func (fi *funcInfo) marshal(w *bytes.Buffer, spec *Spec) error { + id, err := spec.Add(fi.fn) if err != nil { return err } diff --git a/btf/handle.go b/btf/handle.go index 9a864d177..75102dba5 100644 --- a/btf/handle.go +++ b/btf/handle.go @@ -30,21 +30,27 @@ func NewHandle(spec *Spec) (*Handle, error) { return nil, fmt.Errorf("can't load %s BTF on %s", spec.byteOrder, internal.NativeEndian) } - enc := newEncoder(kernelEncoderOptions, newStringTableBuilderFromTable(spec.strings)) + if spec.firstTypeID() != 0 { + return nil, fmt.Errorf("split BTF can't be loaded into the kernel") + } - for _, typ := range spec.types { - _, err := enc.Add(typ) - if err != nil { - return nil, fmt.Errorf("add %s: %w", typ, err) - } + buf := getBuffer() + defer putBuffer(buf) + + var stb *stringTableBuilder + if spec.strings != nil { + // Use the ELF string table as an estimate of the final + // string table size. We don't use the ELF string + // table since the types may have been changed in the meantime. + stb = newStringTableBuilder(spec.strings.Num()) } - btf, err := enc.Encode() + err := marshalTypes(buf, spec.types, stb, kernelMarshalOptions) if err != nil { return nil, fmt.Errorf("marshal BTF: %w", err) } - return newHandleFromRawBTF(btf) + return newHandleFromRawBTF(buf.Bytes()) } func newHandleFromRawBTF(btf []byte) (*Handle, error) { diff --git a/btf/marshal.go b/btf/marshal.go index 4ae479bd9..68cd05393 100644 --- a/btf/marshal.go +++ b/btf/marshal.go @@ -6,141 +6,114 @@ import ( "errors" "fmt" "math" + "sync" "github.com/cilium/ebpf/internal" ) -type encoderOptions struct { - ByteOrder binary.ByteOrder +type marshalOptions struct { // Remove function linkage information for compatibility with <5.6 kernels. StripFuncLinkage bool } -// kernelEncoderOptions will generate BTF suitable for the current kernel. 
-var kernelEncoderOptions encoderOptions - -func init() { - kernelEncoderOptions = encoderOptions{ - ByteOrder: internal.NativeEndian, - StripFuncLinkage: haveFuncLinkage() != nil, - } +// kernelMarshalOptions will generate BTF suitable for the current kernel. +var kernelMarshalOptions = &marshalOptions{ + StripFuncLinkage: haveFuncLinkage() != nil, } // encoder turns Types into raw BTF. type encoder struct { - opts encoderOptions - - buf *bytes.Buffer - strings *stringTableBuilder - allocatedIDs map[Type]TypeID - nextID TypeID - // Temporary storage for Add. - pending internal.Deque[Type] - // Temporary storage for deflateType. - raw rawType + marshalOptions + + byteOrder binary.ByteOrder + pending internal.Deque[Type] + buf *bytes.Buffer + strings *stringTableBuilder + ids map[Type]TypeID + lastID TypeID } -// newEncoder returns a new builder for the given byte order. -// -// See [KernelEncoderOptions] to build BTF for the current system. -func newEncoder(opts encoderOptions, strings *stringTableBuilder) *encoder { - enc := &encoder{ - opts: opts, - buf: bytes.NewBuffer(make([]byte, btfHeaderLen)), - } - enc.reset(strings) - return enc -} +var emptyBTFHeader = make([]byte, btfHeaderLen) -// Reset internal state to be able to reuse the Encoder. -func (e *encoder) Reset() { - e.reset(nil) +var bufferPool = sync.Pool{ + New: func() any { + return bytes.NewBuffer(make([]byte, btfHeaderLen+128)) + }, } -func (e *encoder) reset(strings *stringTableBuilder) { - if strings == nil { - strings = newStringTableBuilder() - } +func getBuffer() *bytes.Buffer { + buf := bufferPool.Get().(*bytes.Buffer) + buf.Reset() + return buf +} - e.buf.Truncate(btfHeaderLen) - e.strings = strings - e.allocatedIDs = make(map[Type]TypeID) - e.nextID = 1 +func putBuffer(buf *bytes.Buffer) { + bufferPool.Put(buf) } -// Add a Type. +// marshalTypes encodes a slice of types into BTF wire format. +// +// types are guaranteed to be written in the order they are passed to this +// function. 
The first type must always be Void. // -// Adding the same Type multiple times is valid and will return a stable ID. +// Doesn't support encoding split BTF since it's not possible to load +// that into the kernel and we don't have a use case for writing BTF +// out again. // -// Calling the method has undefined behaviour if it previously returned an error. -func (e *encoder) Add(typ Type) (TypeID, error) { - if typ == nil { - return 0, errors.New("cannot Add a nil Type") +// w should be retrieved from bufferPool. opts may be nil. +func marshalTypes(w *bytes.Buffer, types []Type, stb *stringTableBuilder, opts *marshalOptions) error { + if len(types) < 1 { + return errors.New("types must contain at least Void") } - hasID := func(t Type) (skip bool) { - _, isVoid := t.(*Void) - _, alreadyEncoded := e.allocatedIDs[t] - return isVoid || alreadyEncoded + if _, ok := types[0].(*Void); !ok { + return fmt.Errorf("first type is %s, not Void", types[0]) } + types = types[1:] - e.pending.Reset() - - allocateID := func(typ Type) { - e.pending.Push(typ) - e.allocatedIDs[typ] = e.nextID - e.nextID++ + if stb == nil { + stb = newStringTableBuilder(0) } - iter := postorderTraversal(typ, hasID) - for iter.Next() { - if hasID(iter.Type) { - // This type is part of a cycle and we've already deflated it. - continue - } - - // Allocate an ID for the next type. - allocateID(iter.Type) - - for !e.pending.Empty() { - t := e.pending.Shift() + e := encoder{ + byteOrder: internal.NativeEndian, + buf: w, + strings: stb, + ids: make(map[Type]TypeID, len(types)), + } - // Ensure that all direct descendants have been allocated an ID - // before calling deflateType. - walkType(t, func(child *Type) { - if !hasID(*child) { - // t refers to a type which hasn't been allocated an ID - // yet, which only happens for circular types. 
- allocateID(*child) - } - }) + if opts != nil { + e.marshalOptions = *opts + } - if err := e.deflateType(t); err != nil { - return 0, fmt.Errorf("deflate %s: %w", t, err) - } + // Ensure that passed types are marshaled in the exact order they were + // passed. + e.pending.Grow(len(types)) + for _, typ := range types { + if err := e.allocateID(typ); err != nil { + return err } } - return e.allocatedIDs[typ], nil -} - -// Encode the raw BTF blob. -// -// The returned slice is valid until the next call to Add. -func (e *encoder) Encode() ([]byte, error) { - length := e.buf.Len() + // Reserve space for the BTF header. + _, _ = e.buf.Write(emptyBTFHeader) - // Truncate the string table on return to allow adding more types. - defer e.buf.Truncate(length) + if err := e.deflatePending(); err != nil { + return err + } + length := e.buf.Len() typeLen := uint32(length - btfHeaderLen) // Reserve space for the string table. stringLen := e.strings.Length() e.buf.Grow(stringLen) + buf := e.strings.AppendEncoded(e.buf.Bytes()) - buf := e.buf.Bytes()[:length+stringLen] - e.strings.MarshalBuffer(buf[length:]) + // Add string table to the unread portion of the buffer, otherwise + // it isn't returned by Bytes(). + // The copy is optimized out since src == dst. + _, _ = e.buf.Write(buf[length:]) // Fill out the header, and write it out. 
header := &btfHeader{ @@ -154,23 +127,102 @@ func (e *encoder) Encode() ([]byte, error) { StringLen: uint32(stringLen), } - err := binary.Write(sliceWriter(buf[:btfHeaderLen]), e.opts.ByteOrder, header) + err := binary.Write(sliceWriter(buf[:btfHeaderLen]), e.byteOrder, header) if err != nil { - return nil, fmt.Errorf("can't write header: %v", err) + return fmt.Errorf("write header: %v", err) + } + + return nil +} + +func (e *encoder) allocateID(typ Type) error { + id := e.lastID + 1 + if id < e.lastID { + return errors.New("type ID overflow") + } + + e.pending.Push(typ) + e.ids[typ] = id + e.lastID = id + return nil +} + +// id returns the ID for the given type or panics with an error. +func (e *encoder) id(typ Type) TypeID { + if _, ok := typ.(*Void); ok { + return 0 + } + + id, ok := e.ids[typ] + if !ok { + panic(fmt.Errorf("no ID for type %v", typ)) + } + + return id +} + +func (e *encoder) deflatePending() error { + // Declare root outside of the loop to avoid repeated heap allocations. + var root Type + skip := func(t Type) (skip bool) { + if t == root { + // Force descending into the current root type even if it already + // has an ID. Otherwise we miss children of types that have their + // ID pre-allocated in marshalTypes. + return false + } + + _, isVoid := t.(*Void) + _, alreadyEncoded := e.ids[t] + return isVoid || alreadyEncoded + } + + for !e.pending.Empty() { + root = e.pending.Shift() + + // Allocate IDs for all children of typ, including transitive dependencies. + iter := postorderTraversal(root, skip) + for iter.Next() { + if iter.Type == root { + // The iterator yields root at the end, do not allocate another ID. 
+ break + } + + if err := e.allocateID(iter.Type); err != nil { + return err + } + } + + if err := e.deflateType(root); err != nil { + id := e.ids[root] + return fmt.Errorf("deflate %v with ID %d: %w", root, id, err) + } } - return buf, nil + return nil } func (e *encoder) deflateType(typ Type) (err error) { - raw := &e.raw - *raw = rawType{} + defer func() { + if r := recover(); r != nil { + var ok bool + err, ok = r.(error) + if !ok { + panic(r) + } + } + }() + + var raw rawType raw.NameOff, err = e.strings.Add(typ.TypeName()) if err != nil { return err } switch v := typ.(type) { + case *Void: + return errors.New("Void is implicit in BTF wire format") + case *Int: raw.SetKind(kindInt) raw.SetSize(v.Size) @@ -184,13 +236,13 @@ func (e *encoder) deflateType(typ Type) (err error) { case *Pointer: raw.SetKind(kindPointer) - raw.SetType(e.allocatedIDs[v.Target]) + raw.SetType(e.id(v.Target)) case *Array: raw.SetKind(kindArray) raw.data = &btfArray{ - e.allocatedIDs[v.Type], - e.allocatedIDs[v.Index], + e.id(v.Type), + e.id(v.Index), v.Nelems, } @@ -223,36 +275,36 @@ func (e *encoder) deflateType(typ Type) (err error) { case *Typedef: raw.SetKind(kindTypedef) - raw.SetType(e.allocatedIDs[v.Type]) + raw.SetType(e.id(v.Type)) case *Volatile: raw.SetKind(kindVolatile) - raw.SetType(e.allocatedIDs[v.Type]) + raw.SetType(e.id(v.Type)) case *Const: raw.SetKind(kindConst) - raw.SetType(e.allocatedIDs[v.Type]) + raw.SetType(e.id(v.Type)) case *Restrict: raw.SetKind(kindRestrict) - raw.SetType(e.allocatedIDs[v.Type]) + raw.SetType(e.id(v.Type)) case *Func: raw.SetKind(kindFunc) - raw.SetType(e.allocatedIDs[v.Type]) - if !e.opts.StripFuncLinkage { + raw.SetType(e.id(v.Type)) + if !e.StripFuncLinkage { raw.SetLinkage(v.Linkage) } case *FuncProto: raw.SetKind(kindFuncProto) - raw.SetType(e.allocatedIDs[v.Return]) + raw.SetType(e.id(v.Return)) raw.SetVlen(len(v.Params)) raw.data, err = e.deflateFuncParams(v.Params) case *Var: raw.SetKind(kindVar) - 
raw.SetType(e.allocatedIDs[v.Type]) + raw.SetType(e.id(v.Type)) raw.data = btfVariable{uint32(v.Linkage)} case *Datasec: @@ -281,7 +333,7 @@ func (e *encoder) deflateType(typ Type) (err error) { return err } - return raw.Marshal(e.buf, e.opts.ByteOrder) + return raw.Marshal(e.buf, e.byteOrder) } func (e *encoder) convertMembers(header *btfType, members []Member) ([]btfMember, error) { @@ -302,7 +354,7 @@ func (e *encoder) convertMembers(header *btfType, members []Member) ([]btfMember bms = append(bms, btfMember{ nameOff, - e.allocatedIDs[member.Type], + e.id(member.Type), uint32(offset), }) } @@ -361,7 +413,7 @@ func (e *encoder) deflateFuncParams(params []FuncParam) ([]btfParam, error) { bps = append(bps, btfParam{ nameOff, - e.allocatedIDs[param.Type], + e.id(param.Type), }) } return bps, nil @@ -371,7 +423,7 @@ func (e *encoder) deflateVarSecinfos(vars []VarSecinfo) []btfVarSecinfo { vsis := make([]btfVarSecinfo, 0, len(vars)) for _, v := range vars { vsis = append(vsis, btfVarSecinfo{ - e.allocatedIDs[v.Type], + e.id(v.Type), v.Offset, v.Size, }) @@ -383,33 +435,24 @@ func (e *encoder) deflateVarSecinfos(vars []VarSecinfo) []btfVarSecinfo { // // The function is intended for the use of the ebpf package and may be removed // at any point in time. 
-func MarshalMapKV(key, value Type) (_ *Handle, keyID, valueID TypeID, _ error) { - enc := nativeEncoderPool.Get().(*encoder) - defer nativeEncoderPool.Put(enc) +func MarshalMapKV(key, value Type) (_ *Handle, keyID, valueID TypeID, err error) { + spec := NewSpec() - enc.Reset() - - var err error if key != nil { - keyID, err = enc.Add(key) + keyID, err = spec.Add(key) if err != nil { - return nil, 0, 0, fmt.Errorf("adding map key to BTF encoder: %w", err) + return nil, 0, 0, fmt.Errorf("add key type: %w", err) } } if value != nil { - valueID, err = enc.Add(value) + valueID, err = spec.Add(value) if err != nil { - return nil, 0, 0, fmt.Errorf("adding map value to BTF encoder: %w", err) + return nil, 0, 0, fmt.Errorf("add value type: %w", err) } } - btf, err := enc.Encode() - if err != nil { - return nil, 0, 0, fmt.Errorf("marshal BTF: %w", err) - } - - handle, err := newHandleFromRawBTF(btf) + handle, err := NewHandle(spec) if err != nil { // Check for 'full' map BTF support, since kernels between 4.18 and 5.2 // already support BTF blobs for maps without Var or Datasec just fine. 
@@ -417,6 +460,5 @@ func MarshalMapKV(key, value Type) (_ *Handle, keyID, valueID TypeID, _ error) { return nil, 0, 0, err } } - return handle, keyID, valueID, err } diff --git a/btf/marshal_test.go b/btf/marshal_test.go index 4dd7f23d6..e58c9c4bb 100644 --- a/btf/marshal_test.go +++ b/btf/marshal_test.go @@ -3,6 +3,7 @@ package btf import ( "bytes" "encoding/binary" + "math" "math/rand" "testing" @@ -19,61 +20,58 @@ func TestBuild(t *testing.T) { Encoding: Signed | Char, } - enc := newEncoder(encoderOptions{ByteOrder: internal.NativeEndian}, nil) - - id, err := enc.Add(typ) - qt.Assert(t, err, qt.IsNil) - qt.Assert(t, id, qt.Equals, TypeID(1), qt.Commentf("First non-void type doesn't get id 1")) - - id, err = enc.Add(typ) - qt.Assert(t, err, qt.IsNil) - qt.Assert(t, id, qt.Equals, TypeID(1), qt.Commentf("Adding a type twice returns different ids")) + want := []Type{ + (*Void)(nil), + typ, + &Pointer{typ}, + &Typedef{"baz", typ}, + } - raw, err := enc.Encode() - qt.Assert(t, err, qt.IsNil, qt.Commentf("Build returned an error")) + var buf bytes.Buffer + qt.Assert(t, marshalTypes(&buf, want, nil, nil), qt.IsNil) - spec, err := loadRawSpec(bytes.NewReader(raw), internal.NativeEndian, nil, nil) + have, err := loadRawSpec(bytes.NewReader(buf.Bytes()), internal.NativeEndian, nil, nil) qt.Assert(t, err, qt.IsNil, qt.Commentf("Couldn't parse BTF")) - - have, err := spec.AnyTypeByName("foo") - qt.Assert(t, err, qt.IsNil) - qt.Assert(t, have, qt.DeepEquals, typ) + qt.Assert(t, have.types, qt.DeepEquals, want) } func TestRoundtripVMlinux(t *testing.T) { types := vmlinuxSpec(t).types // Randomize the order to force different permutations of walking the type - // graph. - rand.Shuffle(len(types), func(i, j int) { - types[i], types[j] = types[j], types[i] + // graph. Keep Void at index 0. 
+ rand.Shuffle(len(types[1:]), func(i, j int) { + types[i+1], types[j+1] = types[j+1], types[i+1] }) - b := newEncoder(kernelEncoderOptions, nil) - + seen := make(map[Type]bool) +limitTypes: for i, typ := range types { - _, err := b.Add(typ) - qt.Assert(t, err, qt.IsNil, qt.Commentf("add type #%d: %s", i, typ)) - - if b.nextID >= 65_000 { + iter := postorderTraversal(typ, func(t Type) (skip bool) { + return seen[t] + }) + for iter.Next() { + seen[iter.Type] = true + } + if len(seen) >= math.MaxInt16 { // IDs exceeding math.MaxUint16 can trigger a bug when loading BTF. // This can be removed once the patch lands. // See https://lore.kernel.org/bpf/20220909092107.3035-1-oss@lmb.io/ - break + types = types[:i] + break limitTypes } } - nStr := len(b.strings.strings) - nTypes := len(types) - t.Log(nStr, "strings", nTypes, "types") - t.Log(float64(nStr)/float64(nTypes), "avg strings per type") + var buf bytes.Buffer + qt.Assert(t, marshalTypes(&buf, types, nil, nil), qt.IsNil) - raw, err := b.Encode() - qt.Assert(t, err, qt.IsNil, qt.Commentf("build BTF")) - - rebuilt, err := loadRawSpec(bytes.NewReader(raw), binary.LittleEndian, nil, nil) + rebuilt, err := loadRawSpec(bytes.NewReader(buf.Bytes()), binary.LittleEndian, nil, nil) qt.Assert(t, err, qt.IsNil, qt.Commentf("round tripping BTF failed")) + if n := len(rebuilt.types); n > math.MaxUint16 { + t.Logf("Rebuilt BTF contains %d types which exceeds uint16, test may fail on older kernels", n) + } + h, err := NewHandle(rebuilt) testutils.SkipIfNotSupported(t, err) qt.Assert(t, err, qt.IsNil, qt.Commentf("loading rebuilt BTF failed")) @@ -81,25 +79,14 @@ func TestRoundtripVMlinux(t *testing.T) { } func BenchmarkBuildVmlinux(b *testing.B) { - spec := vmlinuxTestdataSpec(b) + types := vmlinuxTestdataSpec(b).types b.ReportAllocs() b.ResetTimer() - types := spec.types - strings := spec.strings - for i := 0; i < b.N; i++ { - enc := newEncoder(encoderOptions{ByteOrder: internal.NativeEndian}, 
newStringTableBuilderFromTable(strings)) - - for _, typ := range types { - if _, err := enc.Add(typ); err != nil { - b.Fatal(err) - } - } - - _, err := enc.Encode() - if err != nil { + var buf bytes.Buffer + if err := marshalTypes(&buf, types, nil, nil); err != nil { b.Fatal(err) } } diff --git a/btf/strings.go b/btf/strings.go index deeaeacaa..d0eb604e8 100644 --- a/btf/strings.go +++ b/btf/strings.go @@ -89,15 +89,6 @@ func (st *stringTable) lookup(offset uint32) (string, error) { return st.strings[i], nil } -func (st *stringTable) Length() int { - if len(st.offsets) == 0 || len(st.strings) == 0 { - return 0 - } - - last := len(st.offsets) - 1 - return int(st.offsets[last]) + len(st.strings[last]) + 1 -} - func (st *stringTable) Marshal(w io.Writer) error { for _, str := range st.strings { _, err := io.WriteString(w, str) @@ -112,6 +103,11 @@ func (st *stringTable) Marshal(w io.Writer) error { return nil } +// Num returns the number of strings in the table. +func (st *stringTable) Num() int { + return len(st.strings) +} + // search is a copy of sort.Search specialised for uint32. // // Licensed under https://go.dev/LICENSE @@ -141,25 +137,19 @@ type stringTableBuilder struct { // newStringTableBuilder creates a builder with the given capacity. // // capacity may be zero. -func newStringTableBuilder() *stringTableBuilder { - stb := &stringTableBuilder{0, make(map[string]uint32)} - // Ensure that the empty string is at index 0. - stb.append("") - return stb -} +func newStringTableBuilder(capacity int) *stringTableBuilder { + var stb stringTableBuilder -// newStringTableBuilderFromTable creates a new builder from an existing string table. -func newStringTableBuilderFromTable(contents *stringTable) *stringTableBuilder { - stb := &stringTableBuilder{0, make(map[string]uint32, len(contents.strings)+1)} - stb.append("") - - for _, str := range contents.strings { - if str != "" { - stb.append(str) - } + if capacity == 0 { + // Use the runtime's small default size. 
+ stb.strings = make(map[string]uint32) + } else { + stb.strings = make(map[string]uint32, capacity) } - return stb + // Ensure that the empty string is at index 0. + stb.append("") + return &stb } // Add a string to the table. @@ -203,19 +193,13 @@ func (stb *stringTableBuilder) Length() int { return int(stb.length) } -// Marshal a string table into its binary representation. -func (stb *stringTableBuilder) Marshal() []byte { - buf := make([]byte, stb.Length()) - stb.MarshalBuffer(buf) - return buf -} - -// Marshal a string table into a pre-allocated buffer. -// -// The buffer must be at least of size Length(). -func (stb *stringTableBuilder) MarshalBuffer(buf []byte) { +// AppendEncoded appends the string table to the end of the provided buffer. +func (stb *stringTableBuilder) AppendEncoded(buf []byte) []byte { + n := len(buf) + buf = append(buf, make([]byte, stb.Length())...) + strings := buf[n:] for str, offset := range stb.strings { - n := copy(buf[offset:], str) - buf[offset+uint32(n)] = 0 + copy(strings[offset:], str) } + return buf } diff --git a/btf/strings_test.go b/btf/strings_test.go index b44918d2c..6f2bc4d84 100644 --- a/btf/strings_test.go +++ b/btf/strings_test.go @@ -72,9 +72,9 @@ func TestStringTable(t *testing.T) { } func TestStringTableBuilder(t *testing.T) { - stb := newStringTableBuilder() + stb := newStringTableBuilder(0) - _, err := readStringTable(bytes.NewReader(stb.Marshal()), nil) + _, err := readStringTable(bytes.NewReader(stb.AppendEncoded(nil)), nil) qt.Assert(t, err, qt.IsNil, qt.Commentf("Can't parse string table")) _, err = stb.Add("foo\x00bar") @@ -88,7 +88,7 @@ func TestStringTableBuilder(t *testing.T) { foo2, _ := stb.Add("foo") qt.Assert(t, foo1, qt.Equals, foo2, qt.Commentf("Adding the same string returns different offsets")) - table := stb.Marshal() + table := stb.AppendEncoded(nil) if n := bytes.Count(table, []byte("foo")); n != 1 { t.Fatalf("Marshalled string table contains foo %d times instead of once", n) } diff --git 
a/btf/traversal.go b/btf/traversal.go index fa42815f3..a3a9dec94 100644 --- a/btf/traversal.go +++ b/btf/traversal.go @@ -15,7 +15,7 @@ type postorderIterator struct { // The root type. May be nil if skip(root) is true. root Type - // Contains types which need to be either walked or passed to the callback. + // Contains types which need to be either walked or yielded. types typeDeque // Contains a boolean whether the type has been walked or not. walked internal.Deque[bool] @@ -26,9 +26,8 @@ type postorderIterator struct { Type Type } -// postorderTraversal calls fn for all types reachable from root. -// -// fn is invoked on children of root before root itself. +// postorderTraversal iterates all types reachable from root by visiting the +// leaves of the graph first. // // Types for which skip returns true are ignored. skip may be nil. func postorderTraversal(root Type, skip func(Type) (skip bool)) postorderIterator { diff --git a/btf/types.go b/btf/types.go index e344bbdc3..970b8bc5e 100644 --- a/btf/types.go +++ b/btf/types.go @@ -17,6 +17,17 @@ const maxTypeDepth = 32 type TypeID uint32 // Type represents a type described by BTF. +// +// Identity of Type follows the [Go specification]: two Types are considered +// equal if they have the same concrete type and the same dynamic value, aka +// they point at the same location in memory. This means that the following +// Types are considered distinct even though they have the same "shape". +// +// a := &Int{Size: 1} +// b := &Int{Size: 1} +// a != b +// +// [Go specification]: https://go.dev/ref/spec#Comparison_operators type Type interface { // Type can be formatted using the %s and %v verbs. %s outputs only the // identity of the type, without any detail. %v outputs additional detail. 
diff --git a/internal/deque.go b/internal/deque.go index 05be23e61..e3a305021 100644 --- a/internal/deque.go +++ b/internal/deque.go @@ -24,24 +24,11 @@ func (dq *Deque[T]) Empty() bool { return dq.read == dq.write } -func (dq *Deque[T]) remainingCap() int { - return len(dq.elems) - int(dq.write-dq.read) -} - // Push adds an element to the end. func (dq *Deque[T]) Push(e T) { - if dq.remainingCap() >= 1 { - dq.elems[dq.write&dq.mask] = e - dq.write++ - return - } - - elems := dq.linearise(1) - elems = append(elems, e) - - dq.elems = elems[:cap(elems)] - dq.mask = uint64(cap(elems)) - 1 - dq.read, dq.write = 0, uint64(len(elems)) + dq.Grow(1) + dq.elems[dq.write&dq.mask] = e + dq.write++ } // Shift returns the first element or the zero value. @@ -74,16 +61,17 @@ func (dq *Deque[T]) Pop() T { return t } -// linearise the contents of the deque. -// -// The returned slice has space for at least n more elements and has power -// of two capacity. -func (dq *Deque[T]) linearise(n int) []T { - length := dq.write - dq.read - need := length + uint64(n) - if need < length { +// Grow the deque's capacity, if necessary, to guarantee space for another n +// elements. +func (dq *Deque[T]) Grow(n int) { + have := dq.write - dq.read + need := have + uint64(n) + if need < have { panic("overflow") } + if uint64(len(dq.elems)) >= need { + return + } // Round up to the new power of two which is at least 8. 
// See https://jameshfisher.com/2018/03/30/round-up-power-2/ @@ -92,9 +80,12 @@ func (dq *Deque[T]) linearise(n int) []T { capacity = 8 } - types := make([]T, length, capacity) + elems := make([]T, have, capacity) pivot := dq.read & dq.mask - copied := copy(types, dq.elems[pivot:]) - copy(types[copied:], dq.elems[:pivot]) - return types + copied := copy(elems, dq.elems[pivot:]) + copy(elems[copied:], dq.elems[:pivot]) + + dq.elems = elems[:capacity] + dq.mask = uint64(capacity) - 1 + dq.read, dq.write = 0, have } diff --git a/internal/deque_test.go b/internal/deque_test.go index d611c0719..88123e748 100644 --- a/internal/deque_test.go +++ b/internal/deque_test.go @@ -59,22 +59,24 @@ func TestDeque(t *testing.T) { } }) - t.Run("linearise", func(t *testing.T) { + t.Run("grow", func(t *testing.T) { var td Deque[int] td.Push(1) td.Push(2) + td.Push(3) + td.Shift() - all := td.linearise(0) - if len(all) != 2 { - t.Fatal("Expected 2 elements, got", len(all)) + td.Grow(7) + if len(td.elems) < 9 { + t.Fatal("Expected at least 9 elements, got", len(td.elems)) } - if cap(all)&(cap(all)-1) != 0 { - t.Fatalf("Capacity %d is not a power of two", cap(all)) + if cap(td.elems)&(cap(td.elems)-1) != 0 { + t.Fatalf("Capacity %d is not a power of two", cap(td.elems)) } - if all[0] != 1 || all[1] != 2 { - t.Fatal("Elements don't match") + if td.Shift() != 2 || td.Shift() != 3 { + t.Fatal("Elements don't match after grow") } }) }