Skip to content

Commit

Permalink
btf: make copying types infallible
Browse files Browse the repository at this point in the history
We currently have an unexported copyTypes function which allows transforming
a type graph before copying. This is very useful to strip qualifiers or
typedefs out of a graph. The function currently returns an error, since we
might not be able to remove all qualifiers, etc. in the case of deeply
nested or cyclical types.

Instead of returning an error, introduce an invalid sentinel Type "cycle" and
use it to signal that something went wrong. This allows removing the error return
which in turn allows us to extend Copy to the more powerful copyType paradigm.
  • Loading branch information
lmb committed Apr 21, 2022
1 parent e3e0e9e commit 9abf0c5
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 106 deletions.
2 changes: 1 addition & 1 deletion internal/btf/btf.go
Expand Up @@ -527,7 +527,7 @@ func fixupDatasec(rawTypes []rawType, rawStrings stringTable, sectionSizes map[s

// Copy creates a copy of Spec.
func (s *Spec) Copy() *Spec {
types, _ := copyTypes(s.types, nil)
types := copyTypes(s.types, nil)

namedTypes := make(map[essentialName][]Type)
for _, typ := range types {
Expand Down
38 changes: 2 additions & 36 deletions internal/btf/core.go
Expand Up @@ -269,19 +269,13 @@ var errImpossibleRelocation = errors.New("impossible relocation")
// the better the target is.
func coreCalculateFixups(byteOrder binary.ByteOrder, local Type, targets []Type, relos CORERelos) ([]COREFixup, error) {
localID := local.ID()
local, err := copyType(local, skipQualifiersAndTypedefs)
if err != nil {
return nil, err
}
local = Copy(local, UnderlyingType)

bestScore := len(relos)
var bestFixups []COREFixup
for i := range targets {
targetID := targets[i].ID()
target, err := copyType(targets[i], skipQualifiersAndTypedefs)
if err != nil {
return nil, err
}
target := Copy(targets[i], UnderlyingType)

score := 0 // lower is better
fixups := make([]COREFixup, 0, len(relos))
Expand Down Expand Up @@ -1009,31 +1003,3 @@ func coreAreMembersCompatible(localType Type, targetType Type) error {
return fmt.Errorf("type %s: %w", localType, ErrNotSupported)
}
}

func skipQualifiersAndTypedefs(typ Type) (Type, error) {
result := typ
for depth := 0; depth <= maxTypeDepth; depth++ {
switch v := (result).(type) {
case qualifier:
result = v.qualify()
case *Typedef:
result = v.Type
default:
return result, nil
}
}
return nil, errors.New("exceeded type depth")
}

func skipQualifiers(typ Type) (Type, error) {
result := typ
for depth := 0; depth <= maxTypeDepth; depth++ {
switch v := (result).(type) {
case qualifier:
result = v.qualify()
default:
return result, nil
}
}
return nil, errors.New("exceeded type depth")
}
11 changes: 5 additions & 6 deletions internal/btf/core_test.go
Expand Up @@ -584,8 +584,9 @@ func TestCORECopyWithoutQualifiers(t *testing.T) {
root := &Volatile{}
root.Type = test.fn(root)

_, err := copyType(root, skipQualifiersAndTypedefs)
qt.Assert(t, err, qt.Not(qt.IsNil))
cycle, ok := Copy(root, UnderlyingType).(*cycle)
qt.Assert(t, ok, qt.IsTrue)
qt.Assert(t, cycle.root, qt.Equals, root)
})
}

Expand All @@ -595,8 +596,7 @@ func TestCORECopyWithoutQualifiers(t *testing.T) {
v := a.fn(&Pointer{Target: b.fn(&Int{Name: "z"})})
want := &Pointer{Target: &Int{Name: "z"}}

got, err := copyType(v, skipQualifiersAndTypedefs)
qt.Assert(t, err, qt.IsNil)
got := Copy(v, UnderlyingType)
qt.Assert(t, got, qt.DeepEquals, want)
})
}
Expand All @@ -611,8 +611,7 @@ func TestCORECopyWithoutQualifiers(t *testing.T) {
t.Log(q.name)
}

got, err := copyType(v, skipQualifiersAndTypedefs)
qt.Assert(t, err, qt.IsNil)
got := Copy(v, UnderlyingType)
qt.Assert(t, got, qt.DeepEquals, root)
})
}
42 changes: 22 additions & 20 deletions internal/btf/format.go
Expand Up @@ -63,12 +63,7 @@ func (gf *GoFormatter) writeTypeDecl(name string, typ Type) error {
return fmt.Errorf("need a name for type %s", typ)
}

typ, err := skipQualifiers(typ)
if err != nil {
return err
}

switch v := typ.(type) {
switch v := skipQualifiers(typ).(type) {
case *Enum:
fmt.Fprintf(&gf.w, "type %s int32", name)
if len(v.Values) == 0 {
Expand All @@ -83,10 +78,11 @@ func (gf *GoFormatter) writeTypeDecl(name string, typ Type) error {
gf.w.WriteString(")")

return nil
}

fmt.Fprintf(&gf.w, "type %s ", name)
return gf.writeTypeLit(typ, 0)
default:
fmt.Fprintf(&gf.w, "type %s ", name)
return gf.writeTypeLit(v, 0)
}
}

// writeType outputs the name of a named type or a literal describing the type.
Expand All @@ -96,10 +92,7 @@ func (gf *GoFormatter) writeTypeDecl(name string, typ Type) error {
// foo (if foo is a named type)
// uint32
func (gf *GoFormatter) writeType(typ Type, depth int) error {
typ, err := skipQualifiers(typ)
if err != nil {
return err
}
typ = skipQualifiers(typ)

name := gf.Names[typ]
if name != "" {
Expand All @@ -124,12 +117,8 @@ func (gf *GoFormatter) writeTypeLit(typ Type, depth int) error {
return errNestedTooDeep
}

typ, err := skipQualifiers(typ)
if err != nil {
return err
}

switch v := typ.(type) {
var err error
switch v := skipQualifiers(typ).(type) {
case *Int:
gf.writeIntLit(v)

Expand All @@ -154,7 +143,7 @@ func (gf *GoFormatter) writeTypeLit(typ Type, depth int) error {
err = gf.writeDatasecLit(v, depth)

default:
return fmt.Errorf("type %s: %w", typ, ErrNotSupported)
return fmt.Errorf("type %T: %w", v, ErrNotSupported)
}

if err != nil {
Expand Down Expand Up @@ -302,3 +291,16 @@ func (gf *GoFormatter) writePadding(bytes uint32) {
fmt.Fprintf(&gf.w, "_ [%d]byte; ", bytes)
}
}

func skipQualifiers(typ Type) Type {
result := typ
for depth := 0; depth <= maxTypeDepth; depth++ {
switch v := (result).(type) {
case qualifier:
result = v.qualify()
default:
return result
}
}
return &cycle{typ}
}
65 changes: 27 additions & 38 deletions internal/btf/types.go
Expand Up @@ -541,6 +541,20 @@ func (f *Float) copy() Type {
return &cpy
}

// cycle is a type which had to be elided due since it exceeded maxNestingDepth.
type cycle struct {
root Type
}

func (c *cycle) ID() TypeID { return math.MaxUint32 }
func (c *cycle) String() string { return fmt.Sprintf("cycle[%s]", c.root) }
func (c *cycle) TypeName() string { return "" }
func (c *cycle) walk(*typeDeque) {}
func (c *cycle) copy() Type {
cpy := *c
return &cpy
}

type sizer interface {
size() uint32
}
Expand Down Expand Up @@ -620,12 +634,7 @@ func Sizeof(typ Type) (int, error) {
//
// Currently only supports the subset of types necessary for bitfield relocations.
func alignof(typ Type) (int, error) {
typ, err := skipQualifiersAndTypedefs(typ)
if err != nil {
return 0, err
}

switch t := typ.(type) {
switch t := UnderlyingType(typ).(type) {
case *Enum:
return int(t.size()), nil
case *Int:
Expand All @@ -635,44 +644,34 @@ func alignof(typ Type) (int, error) {
}
}

// Copy a Type recursively.
func Copy(typ Type) Type {
typ, _ = copyType(typ, nil)
return typ
}

// copy a Type recursively.
//
// typ may form a cycle.
//
// Returns any errors from transform verbatim.
func copyType(typ Type, transform func(Type) (Type, error)) (Type, error) {
// typ may form a cycle. If transform is not nil, it is called with the
// to be copied type, and the return value is copied instead.
func Copy(typ Type, transform func(Type) Type) Type {
copies := make(copier)
return typ, copies.copy(&typ, transform)
copies.copy(&typ, transform)
return typ
}

// copy a slice of Types recursively.
//
// Types may form a cycle.
//
// Returns any errors from transform verbatim.
func copyTypes(types []Type, transform func(Type) (Type, error)) ([]Type, error) {
// See Copy for the semantics.
func copyTypes(types []Type, transform func(Type) Type) []Type {
result := make([]Type, len(types))
copy(result, types)

copies := make(copier)
for i := range result {
if err := copies.copy(&result[i], transform); err != nil {
return nil, err
}
copies.copy(&result[i], transform)
}

return result, nil
return result
}

type copier map[Type]Type

func (c copier) copy(typ *Type, transform func(Type) (Type, error)) error {
func (c copier) copy(typ *Type, transform func(Type) Type) {
var work typeDeque
for t := typ; t != nil; t = work.pop() {
// *t is the identity of the type.
Expand All @@ -683,11 +682,7 @@ func (c copier) copy(typ *Type, transform func(Type) (Type, error)) error {

var cpy Type
if transform != nil {
tf, err := transform(*t)
if err != nil {
return fmt.Errorf("copy %s: %w", *t, err)
}
cpy = tf.copy()
cpy = transform(*t).copy()
} else {
cpy = (*t).copy()
}
Expand All @@ -698,8 +693,6 @@ func (c copier) copy(typ *Type, transform func(Type) (Type, error)) error {
// Mark any nested types for copying.
cpy.walk(&work)
}

return nil
}

// typeDeque keeps track of pointers to types which still
Expand Down Expand Up @@ -1011,9 +1004,6 @@ func newEssentialName(name string) essentialName {
}

// UnderlyingType skips qualifiers and Typedefs.
//
// May return typ verbatim if too many types have to be skipped to protect against
// circular Types.
func UnderlyingType(typ Type) Type {
result := typ
for depth := 0; depth <= maxTypeDepth; depth++ {
Expand All @@ -1026,6 +1016,5 @@ func UnderlyingType(typ Type) Type {
return result
}
}
// Return the original argument, since we can't find an underlying type.
return typ
return &cycle{typ}
}
9 changes: 5 additions & 4 deletions internal/btf/types_test.go
Expand Up @@ -35,24 +35,24 @@ func TestSizeof(t *testing.T) {
}

func TestCopyType(t *testing.T) {
_, _ = copyType((*Void)(nil), nil)
_ = Copy((*Void)(nil), nil)

in := &Int{Size: 4}
out, _ := copyType(in, nil)
out := Copy(in, nil)

in.Size = 8
if size := out.(*Int).Size; size != 4 {
t.Error("Copy doesn't make a copy, expected size 4, got", size)
}

t.Run("cyclical", func(t *testing.T) {
_, _ = copyType(newCyclicalType(2), nil)
_ = Copy(newCyclicalType(2), nil)
})

t.Run("identity", func(t *testing.T) {
u16 := &Int{Size: 2}

out, _ := copyType(&Struct{
out := Copy(&Struct{
Members: []Member{
{Name: "a", Type: u16},
{Name: "b", Type: u16},
Expand Down Expand Up @@ -122,6 +122,7 @@ func TestType(t *testing.T) {
Vars: []VarSecinfo{{Type: &Void{}}},
}
},
func() Type { return &cycle{&Void{}} },
}

compareTypes := cmp.Comparer(func(a, b *Type) bool {
Expand Down
2 changes: 1 addition & 1 deletion internal/cmd/gentypes/main.go
Expand Up @@ -471,7 +471,7 @@ import (
}

func outputPatchedStruct(gf *btf.GoFormatter, w *bytes.Buffer, id string, s *btf.Struct, patches []patch) error {
s = btf.Copy(s).(*btf.Struct)
s = btf.Copy(s, nil).(*btf.Struct)

for i, p := range patches {
if err := p(s); err != nil {
Expand Down

0 comments on commit 9abf0c5

Please sign in to comment.