Skip to content

Commit

Permalink
simplify code by using format.Node
Browse files Browse the repository at this point in the history
  • Loading branch information
josharian committed Jan 17, 2024
1 parent e0454a7 commit 5fa0fa9
Showing 1 changed file with 20 additions and 163 deletions.
183 changes: 20 additions & 163 deletions impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,151 +170,29 @@ func findInterface(input string, srcDir string) (path string, iface Type, err er
}

func typeFromAST(in ast.Expr) (Type, error) {
switch specType := in.(type) {
case *ast.Ident:
// a standalone identifier (Reader) shows up as an Ident
return Type{Name: specType.Name}, nil
case *ast.SelectorExpr:
// an identifier in a different package (io.Reader) shows up as a SelectorExpr
// we need to pull the name out
return Type{Name: specType.X.(*ast.Ident).Name + "." + specType.Sel.Name}, nil
case *ast.StarExpr:
// pointer identifiers (*Reader) show up as a StarExpr
// we need to pull the name out and prefix it with a *
typ, err := typeFromAST(specType.X)
if err != nil {
return Type{}, err
}
typ.Name = "*" + typ.Name
return typ, nil
case *ast.ArrayType:
// slices and arrays ([]Reader) show up as an ArrayType
typ, err := typeFromAST(specType.Elt)
if err != nil {
return Type{}, err
}
prefix := "["
if specType.Len != nil {
prefix += specType.Len.(*ast.BasicLit).Value
}
prefix += "]"
typ.Name = prefix + typ.Name
return typ, nil
case *ast.MapType:
// maps (map[string]Reader) show up as a MapType
key, err := typeFromAST(specType.Key)
if err != nil {
return Type{}, err
}
value, err := typeFromAST(specType.Value)
if err != nil {
return Type{}, err
}
return Type{
Name: "map[" + key.String() + "]" + value.String(),
}, nil
case *ast.FuncType:
// funcs (func() Reader) show up as a FuncType
// NOTE: we don't actually parse out the type params of a FuncType.
// This should be okay, because we really only care about
// parsing out the type params when parsing interface
// identifiers. And FuncTypes never signify an interface
// identifier, they're just an argument to it or a type param
// of it.
// We don't parse them out anyways like we do for everything
// else because funcs, alone, are pretty weird in how they use
// generics.
// For everything else, it's identifier[params].
// For funcs, the params get stuck in the middle of the identifier:
// func Foo[Param1, Param2](context.Context, Param1) Param2
// We're gonna need to complicate everything to support that
// construction, and we don't actually need the deconstructed
// bits, so we're just... not going to deconstruct it at all.
var res strings.Builder
res.WriteString("func")
if specType.TypeParams != nil && len(specType.TypeParams.List) > 0 {
res.WriteString("[")
paramList, err := buildFuncParamList(specType.TypeParams.List)
if err != nil {
return Type{}, err
}
res.WriteString(paramList)
res.WriteString("]")
}
res.WriteString("(")
if specType.Params != nil {
paramList, err := buildFuncParamList(specType.Params.List)
if err != nil {
return Type{}, err
}
res.WriteString(paramList)
}
res.WriteString(")")
if specType.Results != nil && len(specType.Results.List) > 0 {
res.WriteString(" ")
if len(specType.Results.List) > 1 {
res.WriteString("(")
}
paramList, err := buildFuncParamList(specType.Results.List)
if err != nil {
return Type{}, err
}
res.WriteString(paramList)
if len(specType.Results.List) > 1 {
res.WriteString(")")
}
}
return Type{Name: res.String()}, nil
case *ast.ChanType:
var res strings.Builder
// channels (chan Reader) show up as a ChanType
// we need to be careful to preserve send/receive semantics
if specType.Dir&ast.SEND == 0 {
// this is a receive-only channel, write the arrow before the chan keyword
res.WriteString("<-")
}
res.WriteString("chan")
if specType.Dir&ast.RECV == 0 {
// this is a send-only channel, write the arrow after the chan keyword
res.WriteString("<-")
}
res.WriteString(" ")
valType, err := typeFromAST(specType.Value)
if err != nil {
return Type{}, err
}
valType.Name = res.String() + valType.Name
return valType, nil
// Extract type name and params from generic types.
var typeName ast.Expr
var typeParams []ast.Expr
switch in := in.(type) {
case *ast.IndexExpr:
// a generic type with one type parameter (Reader[Foo]) shows up as an IndexExpr
id, err := typeFromAST(specType.X)
if err != nil {
return Type{}, err
}
if len(id.Params) > 0 {
return Type{}, fmt.Errorf("got type parameters for a type name: %s", id.String())
}
param, err := typeFromAST(specType.Index)
if err != nil {
return Type{}, err
}
return Type{
Name: id.Name,
Params: []string{param.String()},
}, nil
typeName = in.X
typeParams = []ast.Expr{in.Index}
case *ast.IndexListExpr:
// a generic type with multiple type parameters shows up as an IndexListExpr
id, err := typeFromAST(specType.X)
typeName = in.X
typeParams = in.Indices
}
if typeParams != nil {
id, err := typeFromAST(typeName)
if err != nil {
return Type{}, err
}
if len(id.Params) > 0 {
return Type{}, fmt.Errorf("got type parameters for a type ID: %s", id.String())
return Type{}, fmt.Errorf("unexpected type parameters: %v", in)
}
res := Type{
Name: id.Name,
}
for _, typeParam := range specType.Indices {
res := Type{Name: id.Name}
for _, typeParam := range typeParams {
param, err := typeFromAST(typeParam)
if err != nil {
return Type{}, err
Expand All @@ -323,34 +201,13 @@ func typeFromAST(in ast.Expr) (Type, error) {
}
return res, nil
}
return Type{}, fmt.Errorf("unexpected AST type %T", in)
}

// buildFuncParamList returns a string representation of a list of function
// params (type params, function arguments, returns) given an []*ast.Field
// for those things.
func buildFuncParamList(list []*ast.Field) (string, error) {
var res strings.Builder
for pos, field := range list {
for namePos, name := range field.Names {
res.WriteString(name.Name)
if namePos+1 < len(field.Names) {
res.WriteString(", ")
}
}
if len(field.Names) > 0 {
res.WriteString(" ")
}
fieldType, err := typeFromAST(field.Type)
if err != nil {
return "", err
}
res.WriteString(fieldType.String())
if pos+1 < len(list) {
res.WriteString(", ")
}
// Non-generic type.
buf := new(strings.Builder)
err := format.Node(buf, token.NewFileSet(), in)
if err != nil {
return Type{}, err
}
return res.String(), nil
return Type{Name: buf.String()}, nil
}

// Pkg is a parsed build.Package.
Expand Down

0 comments on commit 5fa0fa9

Please sign in to comment.