From 7f70e7c2fb204fd80012e50d4f98938ef8b8b073 Mon Sep 17 00:00:00 2001 From: Paddy Carver Date: Sun, 18 Dec 2022 20:14:59 -0800 Subject: [PATCH 01/20] Update for Go 1.19 and support generics. Update go.mod to indicate support for Go 1.19. Add support for generic interfaces, i.e. interfaces that accept a type parameter. For example, you can now have this kind of interface: ```go type Interface[Kind any] interface { DoTheThing() Kind } ``` and if you run `impl 's StringImpl' 'Interface[string]` it will generate the following code: ```go func (s StringImpl) DoTheThing() string { // normal impl stub here } ``` Fixes josharian/impl#44. --- go.mod | 8 ++- go.sum | 42 ----------- impl.go | 158 +++++++++++++++++++++++++++++++++-------- impl_test.go | 93 +++++++++++++++++++----- testdata/interfaces.go | 96 +++++++++++++++++++++++++ 5 files changed, 305 insertions(+), 92 deletions(-) diff --git a/go.mod b/go.mod index 4ed327c..14b0012 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,10 @@ module github.com/josharian/impl -go 1.14 +go 1.19 + +require golang.org/x/tools v0.4.0 require ( - golang.org/x/tools v0.4.0 - golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect + golang.org/x/mod v0.7.0 // indirect + golang.org/x/sys v0.3.0 // indirect ) diff --git a/go.sum b/go.sum index 626cea3..372a637 100644 --- a/go.sum +++ b/go.sum @@ -1,48 +1,6 @@ -github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/mod v0.2.0 h1:KU7oHjnv3XNWfa5COkzUifxZmxp1TyI7ImMXqFxLwvQ= -golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.7.0 h1:LapD9S96VoQRhi/GrNTqeBJFrUjs5UHCAtTlgwA5oZA= golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.3.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ= golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.5.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20200522201501-cb1345f3a375 h1:SjQ2+AKWgZLc1xej6WSzL+Dfs5Uyd5xcZH1mGC411IA= -golang.org/x/tools v0.0.0-20200522201501-cb1345f3a375/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= -golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.4.0 h1:7mTAgkunk3fr4GAloyyCasadO6h9zSsQZbwvcaIciV4= golang.org/x/tools v0.4.0/go.mod h1:UE5sM2OK9E/d67R0ANs2xJizIymRP5gJU295PvKXxjQ= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 h1:H2TDz8ibqkAF6YGhCdN3jS9O0/s90v0rJh3X/OLHEUk= -golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8= diff --git a/impl.go b/impl.go index 17b16dc..307aac3 100644 --- a/impl.go +++ b/impl.go @@ -26,6 +26,34 @@ var ( flagRecvPkg = flag.String("recvpkg", "", "package name of the receiver") ) +func parseTypeParams(in string) (string, []string, error) { + firstOpenBracket := strings.Index(in, "[") + if firstOpenBracket < 0 { + return in, []string{}, nil + } + // there are type parameters in our interface + id := in[:firstOpenBracket] + firstCloseBracket := strings.LastIndex(in, "]") + if firstCloseBracket < 0 { + // make sure we're closing our list of type parameters + return "", nil, fmt.Errorf("invalid interface name (cannot have [ without ]): %s", in) + } + if firstCloseBracket != len(in)-1 { + // make sure the first close bracket is actually the last character of the interface name + return "", nil, fmt.Errorf("invalid interface name (cannot have ] anywhere except the last character): %s", in) + } + params := strings.Split(in[firstOpenBracket+1:firstCloseBracket], ",") + typeParams := make([]string, 0, len(params)) + for _, param := range params { + typeParams = append(typeParams, strings.TrimSpace(param)) + } + if len(typeParams) < 1 { + // make sure if we're declaring type parameters, we declare at least one + return "", nil, fmt.Errorf("invalid interface name (cannot have empty type parameters): %s", in) + } + return id, typeParams, nil +} + // findInterface returns the import path and identifier of an interface. // For example, given "http.ResponseWriter", findInterface returns // "net/http", "ResponseWriter". @@ -34,9 +62,9 @@ var ( // If an unqualified interface such as "UserDefinedInterface" is given, then // the interface definition is presumed to be in the package within srcDir and // findInterface returns "", "UserDefinedInterface". -func findInterface(iface string, srcDir string) (path string, id string, err error) { - if len(strings.Fields(iface)) != 1 { - return "", "", fmt.Errorf("couldn't parse interface: %s", iface) +func findInterface(iface string, srcDir string) (path string, id string, typeParams []string, err error) { + if len(strings.Fields(iface)) != 1 && !strings.Contains(iface, "[") { + return "", "", nil, fmt.Errorf("couldn't parse interface: %s", iface) } srcPath := filepath.Join(srcDir, "__go_impl__.go") @@ -46,17 +74,23 @@ func findInterface(iface string, srcDir string) (path string, id string, err err dot := strings.LastIndex(iface, ".") // make sure iface does not end with "/" (e.g. reject net/http/) if slash+1 == len(iface) { - return "", "", fmt.Errorf("interface name cannot end with a '/' character: %s", iface) + return "", "", nil, fmt.Errorf("interface name cannot end with a '/' character: %s", iface) } // make sure iface does not end with "." (e.g. reject net/http.) if dot+1 == len(iface) { - return "", "", fmt.Errorf("interface name cannot end with a '.' character: %s", iface) + return "", "", nil, fmt.Errorf("interface name cannot end with a '.' character: %s", iface) } // make sure iface has at least one "." after "/" (e.g. reject net/http/httputil) if strings.Count(iface[slash:], ".") == 0 { - return "", "", fmt.Errorf("invalid interface name: %s", iface) + return "", "", nil, fmt.Errorf("invalid interface name: %s", iface) + } + path = iface[:dot] + id = iface[dot+1:] + id, typeParams, err = parseTypeParams(id) + if err != nil { + return "", "", nil, err } - return iface[:dot], iface[dot+1:], nil + return path, id, typeParams, nil } src := []byte("package hack\n" + "var i " + iface) @@ -64,7 +98,7 @@ func findInterface(iface string, srcDir string) (path string, id string, err err // auto fix the import path. imp, err := imports.Process(srcPath, src, nil) if err != nil { - return "", "", fmt.Errorf("couldn't parse interface: %s", iface) + return "", "", nil, fmt.Errorf("couldn't parse interface: %s", iface) } // imp should now contain an appropriate import. @@ -78,7 +112,7 @@ func findInterface(iface string, srcDir string) (path string, id string, err err qualified := strings.Contains(iface, ".") if len(f.Imports) == 0 && qualified { - return "", "", fmt.Errorf("unrecognized interface: %s", iface) + return "", "", nil, fmt.Errorf("unrecognized interface: %s", iface) } if !qualified { @@ -89,10 +123,22 @@ func findInterface(iface string, srcDir string) (path string, id string, err err // var i Reader decl := f.Decls[0].(*ast.GenDecl) // var i io.Reader spec := decl.Specs[0].(*ast.ValueSpec) // i io.Reader - sel := spec.Type.(*ast.Ident) - id = sel.Name // Reader + if indxExpr, ok := spec.Type.(*ast.IndexExpr); ok { + // a generic type with one type parameter shows up as an IndexExpr + id = indxExpr.X.(*ast.Ident).Name + typeParams = append(typeParams, indxExpr.Index.(*ast.Ident).Name) + } else if indxListExpr, ok := spec.Type.(*ast.IndexListExpr); ok { + // a generic type with multiple type parameters shows up as an IndexListExpr + id = indxListExpr.X.(*ast.Ident).Name + for _, typeParam := range indxListExpr.Indices { + typeParams = append(typeParams, typeParam.(*ast.Ident).Name) + } + } else { + sel := spec.Type.(*ast.Ident) + id = sel.Name // Reader + } - return path, id, nil + return path, id, typeParams, nil } // If qualified, the code looks like: @@ -111,10 +157,22 @@ func findInterface(iface string, srcDir string) (path string, id string, err err } decl := f.Decls[1].(*ast.GenDecl) // var i io.Reader spec := decl.Specs[0].(*ast.ValueSpec) // i io.Reader - sel := spec.Type.(*ast.SelectorExpr) // io.Reader - id = sel.Sel.Name // Reader + if indxExpr, ok := spec.Type.(*ast.IndexExpr); ok { + // a generic type with one type parameter shows up as an IndexExpr + id = indxExpr.X.(*ast.SelectorExpr).Sel.Name + typeParams = append(typeParams, indxExpr.Index.(*ast.Ident).Name) + } else if indxListExpr, ok := spec.Type.(*ast.IndexListExpr); ok { + // a generic type with multiple type parameters shows up as an IndexListExpr + id = indxListExpr.X.(*ast.SelectorExpr).Sel.Name + for _, typeParam := range indxListExpr.Indices { + typeParams = append(typeParams, typeParam.(*ast.Ident).Name) + } + } else { + sel := spec.Type.(*ast.SelectorExpr) // io.Reader + id = sel.Sel.Name // Reader + } - return path, id, nil + return path, id, typeParams, nil } // Pkg is a parsed build.Package. @@ -125,20 +183,27 @@ type Pkg struct { recvPkg string } +// Spec is ast.TypeSpec with the associated comment map. +type Spec struct { + *ast.TypeSpec + ast.CommentMap + TypeParams map[string]string +} + // typeSpec locates the *ast.TypeSpec for type id in the import path. -func typeSpec(path string, id string, srcDir string) (Pkg, *ast.TypeSpec, error) { +func typeSpec(path, id string, typeParams []string, srcDir string) (Pkg, Spec, error) { var pkg *build.Package var err error if path == "" { pkg, err = build.ImportDir(srcDir, 0) if err != nil { - return Pkg{}, nil, fmt.Errorf("couldn't find package in %s: %v", srcDir, err) + return Pkg{}, Spec{}, fmt.Errorf("couldn't find package in %s: %v", srcDir, err) } } else { pkg, err = build.Import(path, srcDir, 0) if err != nil { - return Pkg{}, nil, fmt.Errorf("couldn't find package %s: %v", path, err) + return Pkg{}, Spec{}, fmt.Errorf("couldn't find package %s: %v", path, err) } } @@ -153,8 +218,12 @@ func typeSpec(path string, id string, srcDir string) (Pkg, *ast.TypeSpec, error) } for _, decl := range f.Decls { - decl, ok := decl.(*ast.GenDecl) - if !ok || decl.Tok != token.TYPE { + genDecl, ok := decl.(*ast.GenDecl) + if !ok { + continue + } + decl := genDecl + if decl.Tok != token.TYPE { continue } for _, spec := range decl.Specs { @@ -162,12 +231,31 @@ func typeSpec(path string, id string, srcDir string) (Pkg, *ast.TypeSpec, error) if spec.Name.Name != id { continue } + tParams := make(map[string]string, len(typeParams)) + if spec.TypeParams != nil { + var specParamNames []string + for _, typeParam := range spec.TypeParams.List { + for _, name := range typeParam.Names { + if name == nil { + continue + } + specParamNames = append(specParamNames, name.Name) + } + } + if len(specParamNames) != len(typeParams) { + continue + } + for pos, specParamName := range specParamNames { + tParams[specParamName] = typeParams[pos] + } + } p := Pkg{Package: pkg, FileSet: fset} - return p, spec, nil + s := Spec{TypeSpec: spec, TypeParams: tParams} + return p, s, nil } } } - return Pkg{}, nil, fmt.Errorf("type %s not found in %s", id, path) + return Pkg{}, Spec{}, fmt.Errorf("type %s not found in %s", id, path) } // gofmt pretty-prints e. @@ -203,9 +291,17 @@ func (p Pkg) fullType(e ast.Expr) string { return p.gofmt(e) } -func (p Pkg) params(field *ast.Field) []Param { +func (p Pkg) params(field *ast.Field, genericTypes map[string]string) []Param { var params []Param - typ := p.fullType(field.Type) + var typ string + ident, ok := field.Type.(*ast.Ident) + if !ok || ident == nil { + typ = p.fullType(field.Type) + } else if genType, ok := genericTypes[ident.Name]; ok { + typ = genType + } else { + typ = p.fullType(field.Type) + } for _, name := range field.Names { params = append(params, Param{Name: name.Name, Type: typ}) } @@ -244,12 +340,12 @@ const ( WithoutComments EmitComments = false ) -func (p Pkg) funcsig(f *ast.Field, comments EmitComments) Func { +func (p Pkg) funcsig(f *ast.Field, genericParams map[string]string, cmap ast.CommentMap, comments EmitComments) Func { fn := Func{Name: f.Names[0].Name} typ := f.Type.(*ast.FuncType) if typ.Params != nil { for _, field := range typ.Params.List { - for _, param := range p.params(field) { + for _, param := range p.params(field, genericParams) { // only for method parameters: // assign a blank identifier "_" to an anonymous parameter if param.Name == "" { @@ -261,7 +357,7 @@ func (p Pkg) funcsig(f *ast.Field, comments EmitComments) Func { } if typ.Results != nil { for _, field := range typ.Results.List { - fn.Res = append(fn.Res, p.params(field)...) + fn.Res = append(fn.Res, p.params(field, genericParams)...) } } if comments == WithComments && f.Doc != nil { @@ -286,13 +382,13 @@ func funcs(iface, srcDir, recvPkg string, comments EmitComments) ([]Func, error) } // Locate the interface. - path, id, err := findInterface(iface, srcDir) + path, id, typeParams, err := findInterface(iface, srcDir) if err != nil { return nil, err } // Parse the package and find the interface declaration. - p, spec, err := typeSpec(path, id, srcDir) + p, spec, err := typeSpec(path, id, typeParams, srcDir) if err != nil { return nil, fmt.Errorf("interface %s not found: %s", iface, err) } @@ -319,7 +415,7 @@ func funcs(iface, srcDir, recvPkg string, comments EmitComments) ([]Func, error) continue } - fn := p.funcsig(fndecl, comments) + fn := p.funcsig(fndecl, spec.TypeParams, spec.CommentMap.Filter(fndecl), comments) fns = append(fns, fn) } return fns, nil @@ -450,7 +546,7 @@ to prevent shell globbing. recvs := strings.Fields(recv) receiver := recvs[len(recvs)-1] // note that this correctly handles "s *Struct" and "*Struct" receiver = strings.TrimPrefix(receiver, "*") - pkg, _, err := typeSpec("", receiver, *flagSrcDir) + pkg, _, err := typeSpec("", receiver, nil, *flagSrcDir) if err == nil { recvPkg = pkg.Package.Name } diff --git a/impl_test.go b/impl_test.go index ffed4e4..0a146b7 100644 --- a/impl_test.go +++ b/impl_test.go @@ -20,10 +20,11 @@ func (b errBool) String() string { func TestFindInterface(t *testing.T) { t.Parallel() cases := []struct { - iface string - path string - id string - wantErr bool + iface string + path string + id string + typeParams []string + wantErr bool }{ {iface: "net.Conn", path: "net", id: "Conn"}, {iface: "http.ResponseWriter", path: "net/http", id: "ResponseWriter"}, @@ -34,13 +35,14 @@ func TestFindInterface(t *testing.T) { {iface: "a/b/c/pkg.", wantErr: true}, {iface: "a/b/c/pkg.Typ", path: "a/b/c/pkg", id: "Typ"}, {iface: "gopkg.in/yaml.v2.Unmarshaler", path: "gopkg.in/yaml.v2", id: "Unmarshaler"}, + {iface: "github.com/josharian/impl/testdata.GenericInterface1[string]", path: "github.com/josharian/impl/testdata", id: "GenericInterface1", typeParams: []string{"string"}}, } for _, tt := range cases { tt := tt t.Run(tt.iface, func(t *testing.T) { t.Parallel() - path, id, err := findInterface(tt.iface, ".") + path, id, typeParams, err := findInterface(tt.iface, ".") gotErr := err != nil if tt.wantErr != gotErr { t.Fatalf("findInterface(%q).err=%v want %s", tt.iface, err, errBool(tt.wantErr)) @@ -51,6 +53,14 @@ func TestFindInterface(t *testing.T) { if tt.id != id { t.Errorf("findInterface(%q).id=%q want %q", tt.iface, id, tt.id) } + if len(tt.typeParams) != len(typeParams) { + t.Errorf("findInterface(%q).len(typeParams)=%d want %d", tt.iface, len(typeParams), len(tt.typeParams)) + } + for pos, v := range tt.typeParams { + if v != typeParams[pos] { + t.Errorf("findInterface(%q).typeParams[%d]=%q, want %q", tt.iface, pos, typeParams[pos], v) + } + } }) } } @@ -67,7 +77,7 @@ func TestTypeSpec(t *testing.T) { } for _, tt := range cases { - p, spec, err := typeSpec(tt.path, tt.id, "") + p, spec, err := typeSpec(tt.path, tt.id, nil, "") gotErr := err != nil if tt.wantErr != gotErr { t.Errorf("typeSpec(%q, %q).err=%v want %s", tt.path, tt.id, err, errBool(tt.wantErr)) @@ -77,8 +87,8 @@ func TestTypeSpec(t *testing.T) { if reflect.DeepEqual(p, Pkg{}) { t.Errorf("typeSpec(%q, %q).pkg=Pkg{} want non-nil", tt.path, tt.id) } - if spec == nil { - t.Errorf("typeSpec(%q, %q).spec=nil want non-nil", tt.path, tt.id) + if reflect.DeepEqual(spec, Spec{}) { + t.Errorf("typeSpec(%q, %q).spec=Spec{} want non-nil", tt.path, tt.id) } } } @@ -252,6 +262,25 @@ func TestFuncs(t *testing.T) { comments: WithComments, }, {iface: "net.Tennis", wantErr: true}, + { + iface: "github.com/josharian/impl/testdata.GenericInterface1[int]", + want: []Func{ + { + Name: "Method1", + Res: []Param{{Type: "int"}}, + }, + { + Name: "Method2", + Params: []Param{{Name: "_", Type: "int"}}, + }, + { + Name: "Method3", + Params: []Param{{Name: "_", Type: "int"}}, + Res: []Param{{Type: "int"}}, + }, + }, + comments: WithComments, + }, } for _, tt := range cases { @@ -577,16 +606,48 @@ func TestStubGeneration(t *testing.T) { want: testdata.Interface9Output, dir: ".", }, + { + iface: "github.com/josharian/impl/testdata.GenericInterface1[string]", + want: testdata.GenericInterface1Output, + dir: ".", + }, + { + iface: "GenericInterface1[string]", + want: testdata.GenericInterface1Output, + dir: "testdata", + }, + { + iface: "github.com/josharian/impl/testdata.GenericInterface2[string, bool]", + want: testdata.GenericInterface2Output, + dir: ".", + }, + { + iface: "GenericInterface2[string, bool]", + want: testdata.GenericInterface2Output, + dir: "testdata", + }, + { + iface: "github.com/josharian/impl/testdata.GenericInterface3[string, bool]", + want: testdata.GenericInterface3Output, + dir: ".", + }, + { + iface: "GenericInterface3[string, bool]", + want: testdata.GenericInterface3Output, + dir: "testdata", + }, } for _, tt := range cases { - fns, err := funcs(tt.iface, tt.dir, "", WithComments) - if err != nil { - t.Errorf("funcs(%q).err=%v", tt.iface, err) - } - src := genStubs("r *Receiver", fns, nil) - if string(src) != tt.want { - t.Errorf("genStubs(\"r *Receiver\", %+#v).src=\n%#v\nwant\n%#v\n", fns, string(src), tt.want) - } + t.Run(tt.iface, func(t *testing.T) { + fns, err := funcs(tt.iface, tt.dir, "", WithComments) + if err != nil { + t.Errorf("funcs(%q).err=%v", tt.iface, err) + } + src := genStubs("r *Receiver", fns, nil) + if string(src) != tt.want { + t.Errorf("genStubs(\"r *Receiver\", %+#v).src=\n%#v\nwant\n%#v\n", fns, string(src), tt.want) + } + }) } } diff --git a/testdata/interfaces.go b/testdata/interfaces.go index 3167166..3b9f25b 100644 --- a/testdata/interfaces.go +++ b/testdata/interfaces.go @@ -48,6 +48,42 @@ type Interface3 interface { Method3(arg1, arg2 bool) (result1, result2 bool) } +// GenericInterface1 is a dummy interface to test the program output. This +// interface tests generation of generic interfaces with the specified type +// parameters. +type GenericInterface1[Type any] interface { + // Method1 is the first method of GenericInterface1. + Method1() Type + // Method2 is the second method of GenericInterface1. + Method2(Type) + // Method3 is the third method of GenericInterface1. + Method3(Type) Type +} + +// GenericInterface2 is a dummy interface to test the program output. This +// interface tests generation of generic interfaces with the specified type +// parameters. +type GenericInterface2[Type1 any, Type2 comparable] interface { + // Method1 is the first method of GenericInterface2. + Method1() (Type1, Type2) + // Method2 is the second method of GenericInterface2. + Method2(Type1, Type2) + // Method3 is the third method of GenericInterface2. + Method3(Type1) Type2 +} + +// GenericInterface3 is a dummy interface to test the program output. This +// interface tests generation of generic interfaces with repeated type +// parameters. +type GenericInterface3[Type1, Type2 any] interface { + // Method1 is the first method of GenericInterface3. + Method1() (Type1, Type2) + // Method2 is the second method of GenericInterface3. + Method2(Type1, Type2) + // Method3 is the third method of GenericInterface3. + Method3(Type1) Type2 +} + // Interface1Output is the expected output generated from reflecting on // Interface1, provided that the receiver is equal to 'r *Receiver'. var Interface1Output = `// Method1 is the first method of Interface1. @@ -172,3 +208,63 @@ func (arg3 *Implemented) Method2(arg1 string, arg2 int) (_ error) { } ` + +// GenericInterface1Output is the expected output generated from reflecting on +// GenericInterface1, provided that the receiver is equal to 'r *Receiver' and +// it was generated with the type parameters [string]. +var GenericInterface1Output = `// Method1 is the first method of GenericInterface1. +func (r *Receiver) Method1() string { + panic("not implemented") // TODO: Implement +} + +// Method2 is the second method of GenericInterface1. +func (r *Receiver) Method2(_ string) { + panic("not implemented") // TODO: Implement +} + +// Method3 is the third method of GenericInterface1. +func (r *Receiver) Method3(_ string) string { + panic("not implemented") // TODO: Implement +} + +` + +// GenericInterface2Output is the expected output generated from reflecting on +// GenericInterface2, provided that the receiver is equal to 'r *Receiver' and +// it was generated with the type parameters [string, bool]. +var GenericInterface2Output = `// Method1 is the first method of GenericInterface2. +func (r *Receiver) Method1() (string, bool) { + panic("not implemented") // TODO: Implement +} + +// Method2 is the second method of GenericInterface2. +func (r *Receiver) Method2(_ string, _ bool) { + panic("not implemented") // TODO: Implement +} + +// Method3 is the third method of GenericInterface2. +func (r *Receiver) Method3(_ string) bool { + panic("not implemented") // TODO: Implement +} + +` + +// GenericInterface3Output is the expected output generated from reflecting on +// GenericInterface3, provided that the receiver is equal to 'r *Receiver' and +// it was generated with the type parameters [string, bool]. +var GenericInterface3Output = `// Method1 is the first method of GenericInterface3. +func (r *Receiver) Method1() (string, bool) { + panic("not implemented") // TODO: Implement +} + +// Method2 is the second method of GenericInterface3. +func (r *Receiver) Method2(_ string, _ bool) { + panic("not implemented") // TODO: Implement +} + +// Method3 is the third method of GenericInterface3. +func (r *Receiver) Method3(_ string) bool { + panic("not implemented") // TODO: Implement +} + +` From e36c5964609079fd4a2b06f9991e1f1f7ca6aa42 Mon Sep 17 00:00:00 2001 From: Paddy Carver Date: Sun, 18 Dec 2022 20:24:52 -0800 Subject: [PATCH 02/20] Remove some debugging changes. --- impl.go | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/impl.go b/impl.go index 307aac3..0afe6c0 100644 --- a/impl.go +++ b/impl.go @@ -218,12 +218,8 @@ func typeSpec(path, id string, typeParams []string, srcDir string) (Pkg, Spec, e } for _, decl := range f.Decls { - genDecl, ok := decl.(*ast.GenDecl) - if !ok { - continue - } - decl := genDecl - if decl.Tok != token.TYPE { + decl, ok := decl.(*ast.GenDecl) + if !ok || decl.Tok != token.TYPE { continue } for _, spec := range decl.Specs { From 43f209f7b55985092adbc97a386f223b42f2f7f3 Mon Sep 17 00:00:00 2001 From: Paddy Carver Date: Sun, 1 Jan 2023 19:25:57 -0800 Subject: [PATCH 03/20] Revert go.mod to 1.14. --- go.mod | 7 +------ go.sum | 28 ++++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index 14b0012..a188163 100644 --- a/go.mod +++ b/go.mod @@ -1,10 +1,5 @@ module github.com/josharian/impl -go 1.19 +go 1.14 require golang.org/x/tools v0.4.0 - -require ( - golang.org/x/mod v0.7.0 // indirect - golang.org/x/sys v0.3.0 // indirect -) diff --git a/go.sum b/go.sum index 372a637..fb339e5 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,34 @@ +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.7.0 h1:LapD9S96VoQRhi/GrNTqeBJFrUjs5UHCAtTlgwA5oZA= golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.3.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ= golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.5.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.4.0 h1:7mTAgkunk3fr4GAloyyCasadO6h9zSsQZbwvcaIciV4= golang.org/x/tools v0.4.0/go.mod h1:UE5sM2OK9E/d67R0ANs2xJizIymRP5gJU295PvKXxjQ= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= From cff83330a5c5bf81359fd5017f68952ddf4a13bb Mon Sep 17 00:00:00 2001 From: Paddy Carver Date: Sun, 1 Jan 2023 19:30:38 -0800 Subject: [PATCH 04/20] Set go.mod back to 1.18. Turns out we need to use 1.18 because we have generics in our testdata package now, meaning `go test` won't run unless our go.mod is set to 1.18. --- go.mod | 7 ++++++- go.sum | 28 ---------------------------- 2 files changed, 6 insertions(+), 29 deletions(-) diff --git a/go.mod b/go.mod index a188163..ec0b262 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,10 @@ module github.com/josharian/impl -go 1.14 +go 1.18 require golang.org/x/tools v0.4.0 + +require ( + golang.org/x/mod v0.7.0 // indirect + golang.org/x/sys v0.3.0 // indirect +) diff --git a/go.sum b/go.sum index fb339e5..372a637 100644 --- a/go.sum +++ b/go.sum @@ -1,34 +1,6 @@ -github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.7.0 h1:LapD9S96VoQRhi/GrNTqeBJFrUjs5UHCAtTlgwA5oZA= golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.3.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ= golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.5.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.4.0 h1:7mTAgkunk3fr4GAloyyCasadO6h9zSsQZbwvcaIciV4= golang.org/x/tools v0.4.0/go.mod h1:UE5sM2OK9E/d67R0ANs2xJizIymRP5gJU295PvKXxjQ= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= From bc0371701faf76c292b2e585e65c9e1fa96cc16c Mon Sep 17 00:00:00 2001 From: Paddy Carver Date: Sun, 1 Jan 2023 19:56:08 -0800 Subject: [PATCH 05/20] Document, simplify, and test parseTypeParams. Add a comment documenting the purpose of parseTypeParams. Add tests exercising parseTypeParams. Simplify parseTypeParams by using strings.Cut instead of index math. --- impl.go | 37 ++++++++++++++++++++++++++----------- impl_test.go | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 11 deletions(-) diff --git a/impl.go b/impl.go index 0afe6c0..88a45c2 100644 --- a/impl.go +++ b/impl.go @@ -26,26 +26,41 @@ var ( flagRecvPkg = flag.String("recvpkg", "", "package name of the receiver") ) +// parseTypeParams parses the type parameters from a generic type, returning +// the type and its parameters. +// +// For example, the input "foo[Bar, Baz]" would return "foo", []string{"Bar", "Baz"} +// +// The input does not need to be a generic type; if a type with no type +// parameters is passed in, the type will be returned with no parameters. +// +// For example, the input "foo" would return "foo", []string{} func parseTypeParams(in string) (string, []string, error) { - firstOpenBracket := strings.Index(in, "[") - if firstOpenBracket < 0 { - return in, []string{}, nil - } - // there are type parameters in our interface - id := in[:firstOpenBracket] - firstCloseBracket := strings.LastIndex(in, "]") - if firstCloseBracket < 0 { + id, rest, ok := strings.Cut(in, "[") + if !ok { + // no [ found means this isn't a generic type + // just return the input + return in, nil, nil + } + + // we found a [, there should be type parameters in our interface + paramsString, rest, ok := strings.Cut(rest, "]") + if !ok { // make sure we're closing our list of type parameters return "", nil, fmt.Errorf("invalid interface name (cannot have [ without ]): %s", in) } - if firstCloseBracket != len(in)-1 { + if rest != "" { // make sure the first close bracket is actually the last character of the interface name return "", nil, fmt.Errorf("invalid interface name (cannot have ] anywhere except the last character): %s", in) } - params := strings.Split(in[firstOpenBracket+1:firstCloseBracket], ",") + params := strings.Split(paramsString, ",") typeParams := make([]string, 0, len(params)) for _, param := range params { - typeParams = append(typeParams, strings.TrimSpace(param)) + trimmed := strings.TrimSpace(param) + if trimmed == "" { + continue + } + typeParams = append(typeParams, trimmed) } if len(typeParams) < 1 { // make sure if we're declaring type parameters, we declare at least one diff --git a/impl_test.go b/impl_test.go index 0a146b7..b2302b1 100644 --- a/impl_test.go +++ b/impl_test.go @@ -755,3 +755,51 @@ func TestStubGenerationForRepeatedName(t *testing.T) { }) } } + +func TestParseTypeParams(t *testing.T) { + t.Parallel() + + cases := []struct { + desc string + input string + wantID string + wantParams []string + wantErr bool + }{ + {desc: "non-generic type", input: "Reader", wantID: "Reader"}, + {desc: "one type param", input: "Reader[Foo]", wantID: "Reader", wantParams: []string{"Foo"}}, + {desc: "two type params", input: "Reader[Foo, Bar]", wantID: "Reader", wantParams: []string{"Foo", "Bar"}}, + {desc: "three type params", input: "Reader[Foo, Bar, Baz]", wantID: "Reader", wantParams: []string{"Foo", "Bar", "Baz"}}, + {desc: "no spaces", input: "Reader[Foo,Bar]", wantID: "Reader", wantParams: []string{"Foo", "Bar"}}, + {desc: "unclosed brackets", input: "Reader[Foo", wantErr: true}, + {desc: "no params", input: "Reader[]", wantErr: true}, + {desc: "space-only params", input: "Reader[ ]", wantErr: true}, + {desc: "multiple space-only params", input: "Reader[ , , ]", wantErr: true}, + {desc: "characters after bracket", input: "Reader[Foo]Bar", wantErr: true}, + } + for _, tt := range cases { + tt := tt + t.Run(tt.desc, func(t *testing.T) { + t.Parallel() + + id, params, err := parseTypeParams(tt.input) + if err != nil { + if tt.wantErr { + return + } + t.Fatalf("unexpected error: %s", err) + } + if id != tt.wantID { + t.Errorf("wanted ID %q, got %q", tt.wantID, id) + } + if len(params) != len(tt.wantParams) { + t.Errorf("wanted %d params, got %d: %v", len(tt.wantParams), len(params), params) + } + for pos, param := range params { + if param != tt.wantParams[pos] { + t.Errorf("expected param %d to be %q, got %q: %v", pos, tt.wantParams[pos], param, params) + } + } + }) + } +} From baa0f59c3d455f02b2090137b1f094d17866d950 Mon Sep 17 00:00:00 2001 From: Paddy Carver Date: Sun, 1 Jan 2023 23:27:20 -0800 Subject: [PATCH 06/20] Add a Type abstraction, make type parsing more robust. Instead of parsing generic types using string manipulation, parse them using the go/parse package to get the AST. Rather than bespokely handling the AST stuff everywhere, pull it out into a recursive helper function (we need it to be recursive because generics allow for some pretty complex AST constructs). Rather than passing around a type ID and its params everywhere, create a Type struct that contains both. Pull our type param matching out into a helper function. --- impl.go | 409 ++++++++++++++++++++++++++++++++++++--------------- impl_test.go | 126 +++++++++------- 2 files changed, 368 insertions(+), 167 deletions(-) diff --git a/impl.go b/impl.go index 88a45c2..f8e0728 100644 --- a/impl.go +++ b/impl.go @@ -26,47 +26,50 @@ var ( flagRecvPkg = flag.String("recvpkg", "", "package name of the receiver") ) -// parseTypeParams parses the type parameters from a generic type, returning -// the type and its parameters. -// -// For example, the input "foo[Bar, Baz]" would return "foo", []string{"Bar", "Baz"} -// -// The input does not need to be a generic type; if a type with no type -// parameters is passed in, the type will be returned with no parameters. -// -// For example, the input "foo" would return "foo", []string{} -func parseTypeParams(in string) (string, []string, error) { - id, rest, ok := strings.Cut(in, "[") - if !ok { - // no [ found means this isn't a generic type - // just return the input - return in, nil, nil - } +// Type is a parsed type reference. +type Type struct { + // ID is the type's ID or name. For example, in "foo[Bar, Baz]", the ID + // is "foo". + ID string + + // Params are the type's type params. For example, in "foo[Bar, Baz]", + // the Params are []string{"Bar", "Baz"}. + // + // Params never list the type of the "name type" construction of type + // params used when defining a generic type. They will always be just + // the filling type, as seen when using a generic type. + // + // Params will always be the type parameters only for the top-level + // type; if the params themselves have type parameters, they will + // remain joined to the type name. So "foo[Bar, Baz[Quux]]" will be + // returned as {ID: "foo", Params: []string{"Bar", "Baz[Quux]"}} + Params []string +} - // we found a [, there should be type parameters in our interface - paramsString, rest, ok := strings.Cut(rest, "]") - if !ok { - // make sure we're closing our list of type parameters - return "", nil, fmt.Errorf("invalid interface name (cannot have [ without ]): %s", in) - } - if rest != "" { - // make sure the first close bracket is actually the last character of the interface name - return "", nil, fmt.Errorf("invalid interface name (cannot have ] anywhere except the last character): %s", in) - } - params := strings.Split(paramsString, ",") - typeParams := make([]string, 0, len(params)) - for _, param := range params { - trimmed := strings.TrimSpace(param) - if trimmed == "" { - continue - } - typeParams = append(typeParams, trimmed) - } - if len(typeParams) < 1 { - // make sure if we're declaring type parameters, we declare at least one - return "", nil, fmt.Errorf("invalid interface name (cannot have empty type parameters): %s", in) +// String constructs a reference to the Type. For example: +// Type{ID: "Foo", Params{{ID: "Bar"}, {ID: "Baz", Params: {{ID: "[]Quux"}}}} +// would yield +// Foo[Bar, Baz[[]Quux]] +func (t Type) String() string { + var res strings.Builder + res.WriteString(t.ID) + if len(t.Params) < 1 { + return res.String() + } + res.WriteString("[") + res.WriteString(strings.Join(t.Params, ", ")) + res.WriteString("]") + return res.String() +} + +// parseType parses an interface reference into a Type, allowing us to +// distinguish between the interface's ID or name and its type parameters. +func parseType(in string) (Type, error) { + expr, err := parser.ParseExpr(in) + if err != nil { + return Type{}, err } - return id, typeParams, nil + return typeFromAST(expr) } // findInterface returns the import path and identifier of an interface. @@ -77,43 +80,48 @@ func parseTypeParams(in string) (string, []string, error) { // If an unqualified interface such as "UserDefinedInterface" is given, then // the interface definition is presumed to be in the package within srcDir and // findInterface returns "", "UserDefinedInterface". -func findInterface(iface string, srcDir string) (path string, id string, typeParams []string, err error) { - if len(strings.Fields(iface)) != 1 && !strings.Contains(iface, "[") { - return "", "", nil, fmt.Errorf("couldn't parse interface: %s", iface) +// +// The typeParams return value will be populated for generic types. For example, +// given "foo[Bar, Baz]", the id return value will be "foo", and typeParams will +// be []string{"Bar", "Baz"}. The types of the type parameters should not be +// included; "foo[Bar any, Baz io.Reader]" is invalid. +func findInterface(input string, srcDir string) (path string, iface Type, err error) { + if len(strings.Fields(input)) != 1 && !strings.Contains(input, "[") { + return "", Type{}, fmt.Errorf("couldn't parse interface: %s", input) } srcPath := filepath.Join(srcDir, "__go_impl__.go") - if slash := strings.LastIndex(iface, "/"); slash > -1 { + if slash := strings.LastIndex(input, "/"); slash > -1 { // package path provided - dot := strings.LastIndex(iface, ".") + dot := strings.LastIndex(input, ".") // make sure iface does not end with "/" (e.g. reject net/http/) - if slash+1 == len(iface) { - return "", "", nil, fmt.Errorf("interface name cannot end with a '/' character: %s", iface) + if slash+1 == len(input) { + return "", Type{}, fmt.Errorf("interface name cannot end with a '/' character: %s", input) } // make sure iface does not end with "." (e.g. reject net/http.) - if dot+1 == len(iface) { - return "", "", nil, fmt.Errorf("interface name cannot end with a '.' character: %s", iface) + if dot+1 == len(input) { + return "", Type{}, fmt.Errorf("interface name cannot end with a '.' character: %s", input) } // make sure iface has at least one "." after "/" (e.g. reject net/http/httputil) - if strings.Count(iface[slash:], ".") == 0 { - return "", "", nil, fmt.Errorf("invalid interface name: %s", iface) + if strings.Count(input[slash:], ".") == 0 { + return "", Type{}, fmt.Errorf("invalid interface name: %s", input) } - path = iface[:dot] - id = iface[dot+1:] - id, typeParams, err = parseTypeParams(id) + path = input[:dot] + id := input[dot+1:] + iface, err = parseType(id) if err != nil { - return "", "", nil, err + return "", Type{}, err } - return path, id, typeParams, nil + return path, iface, nil } - src := []byte("package hack\n" + "var i " + iface) + src := []byte("package hack\n" + "var i " + input) // If we couldn't determine the import path, goimports will // auto fix the import path. imp, err := imports.Process(srcPath, src, nil) if err != nil { - return "", "", nil, fmt.Errorf("couldn't parse interface: %s", iface) + return "", Type{}, fmt.Errorf("couldn't parse interface: %s", input) } // imp should now contain an appropriate import. @@ -124,10 +132,10 @@ func findInterface(iface string, srcDir string) (path string, id string, typePar panic(err) } - qualified := strings.Contains(iface, ".") + qualified := strings.Contains(input, ".") if len(f.Imports) == 0 && qualified { - return "", "", nil, fmt.Errorf("unrecognized interface: %s", iface) + return "", Type{}, fmt.Errorf("unrecognized interface: %s", input) } if !qualified { @@ -138,22 +146,8 @@ func findInterface(iface string, srcDir string) (path string, id string, typePar // var i Reader decl := f.Decls[0].(*ast.GenDecl) // var i io.Reader spec := decl.Specs[0].(*ast.ValueSpec) // i io.Reader - if indxExpr, ok := spec.Type.(*ast.IndexExpr); ok { - // a generic type with one type parameter shows up as an IndexExpr - id = indxExpr.X.(*ast.Ident).Name - typeParams = append(typeParams, indxExpr.Index.(*ast.Ident).Name) - } else if indxListExpr, ok := spec.Type.(*ast.IndexListExpr); ok { - // a generic type with multiple type parameters shows up as an IndexListExpr - id = indxListExpr.X.(*ast.Ident).Name - for _, typeParam := range indxListExpr.Indices { - typeParams = append(typeParams, typeParam.(*ast.Ident).Name) - } - } else { - sel := spec.Type.(*ast.Ident) - id = sel.Name // Reader - } - - return path, id, typeParams, nil + iface, err = typeFromAST(spec.Type) + return path, iface, err } // If qualified, the code looks like: @@ -172,22 +166,192 @@ func findInterface(iface string, srcDir string) (path string, id string, typePar } decl := f.Decls[1].(*ast.GenDecl) // var i io.Reader spec := decl.Specs[0].(*ast.ValueSpec) // i io.Reader - if indxExpr, ok := spec.Type.(*ast.IndexExpr); ok { - // a generic type with one type parameter shows up as an IndexExpr - id = indxExpr.X.(*ast.SelectorExpr).Sel.Name - typeParams = append(typeParams, indxExpr.Index.(*ast.Ident).Name) - } else if indxListExpr, ok := spec.Type.(*ast.IndexListExpr); ok { + iface, err = typeFromAST(spec.Type) + return path, iface, err +} + +func typeFromAST(in ast.Expr) (Type, error) { + switch specType := in.(type) { + case *ast.Ident: + // a standalone identifier (Reader) shows up as an Ident + return Type{ID: specType.Name}, nil + case *ast.SelectorExpr: + // an identifier in a different package (io.Reader) shows up as a SelectorExpr + // we need to pull the name out + return Type{ID: specType.Sel.Name}, nil + case *ast.StarExpr: + // pointer identifiers (*Reader) show up as a StarExpr + // we need to pull the name out and prefix it with a * + typ, err := typeFromAST(specType.X) + if err != nil { + return Type{}, err + } + typ.ID = "*" + typ.ID + return typ, nil + case *ast.ArrayType: + // slices and arrays ([]Reader) show up as an ArrayType + typ, err := typeFromAST(specType.Elt) + if err != nil { + return Type{}, err + } + prefix := "[" + if specType.Len != nil { + prefix += specType.Len.(*ast.BasicLit).Value + } + prefix += "]" + typ.ID = prefix + typ.ID + return typ, nil + case *ast.MapType: + // maps (map[string]Reader) show up as a MapType + key, err := typeFromAST(specType.Key) + if err != nil { + return Type{}, err + } + value, err := typeFromAST(specType.Value) + if err != nil { + return Type{}, err + } + return Type{ + ID: "map[" + key.String() + "]" + value.String(), + }, nil + case *ast.FuncType: + // funcs (func() Reader) show up as a FuncType + // NOTE: we don't actually parse out the type params of a FuncType. + // This should be okay, because we really only care about + // parsing out the type params when parsing interface + // identifiers. And FuncTypes never signify an interface + // identifier, they're just an argument to it or a type param + // of it. + // We don't parse them out anyways like we do for everything + // else because funcs, alone, are pretty weird in how they use + // generics. + // For everything else, it's identifier[params]. + // For funcs, the params get stuck in the middle of the identifier: + // func Foo[Param1, Param2](context.Context, Param1) Param2 + // We're gonna need to complicate everything to support that + // construction, and we don't actually need the deconstructed + // bits, so we're just... not going to deconstruct it at all. + var res strings.Builder + res.WriteString("func") + if specType.TypeParams != nil && len(specType.TypeParams.List) > 0 { + res.WriteString("[") + paramList, err := buildFuncParamList(specType.TypeParams.List) + if err != nil { + return Type{}, err + } + res.WriteString(paramList) + res.WriteString("]") + } + res.WriteString("(") + if specType.Params != nil { + paramList, err := buildFuncParamList(specType.Params.List) + if err != nil { + return Type{}, err + } + res.WriteString(paramList) + } + res.WriteString(")") + if specType.Results != nil && len(specType.Results.List) > 0 { + res.WriteString(" ") + if len(specType.Results.List) > 1 { + res.WriteString("(") + } + paramList, err := buildFuncParamList(specType.Results.List) + if err != nil { + return Type{}, err + } + res.WriteString(paramList) + if len(specType.Results.List) > 1 { + res.WriteString(")") + } + } + return Type{ID: res.String()}, nil + case *ast.ChanType: + var res strings.Builder + // channels (chan Reader) show up as a ChanType + // we need to be careful to preserve send/receive semantics + if specType.Dir&ast.SEND == 0 { + // this is a receive-only channel, write the arrow before the chan keyword + res.WriteString("<-") + } + res.WriteString("chan") + if specType.Dir&ast.RECV == 0 { + // this is a send-only channel, write the arrow after the chan keyword + res.WriteString("<-") + } + res.WriteString(" ") + valType, err := typeFromAST(specType.Value) + if err != nil { + return Type{}, err + } + valType.ID = res.String() + valType.ID + return valType, nil + case *ast.IndexExpr: + // a generic type with one type parameter (Reader[Foo]) shows up as an IndexExpr + id, err := typeFromAST(specType.X) + if err != nil { + return Type{}, err + } + if len(id.Params) > 0 { + return Type{}, fmt.Errorf("got type parameters for a type ID, which is very confusing: %s", id.String()) + } + param, err := typeFromAST(specType.Index) + if err != nil { + return Type{}, err + } + return Type{ + ID: id.ID, + Params: []string{param.String()}, + }, nil + case *ast.IndexListExpr: // a generic type with multiple type parameters shows up as an IndexListExpr - id = indxListExpr.X.(*ast.SelectorExpr).Sel.Name - for _, typeParam := range indxListExpr.Indices { - typeParams = append(typeParams, typeParam.(*ast.Ident).Name) + id, err := typeFromAST(specType.X) + if err != nil { + return Type{}, err } - } else { - sel := spec.Type.(*ast.SelectorExpr) // io.Reader - id = sel.Sel.Name // Reader + if len(id.Params) > 0 { + return Type{}, fmt.Errorf("got type parameters for a type ID, which is very confusing: %s", id.String()) + } + res := Type{ + ID: specType.X.(*ast.Ident).Name, + } + for _, typeParam := range specType.Indices { + param, err := typeFromAST(typeParam) + if err != nil { + return Type{}, err + } + res.Params = append(res.Params, param.String()) + } + return res, nil } + return Type{}, fmt.Errorf("unexpected AST type %T", in) +} - return path, id, typeParams, nil +// buildFuncParamList returns a string representation of a list of function +// params (type params, function arguments, returns) given an []*ast.Field +// for those things. +func buildFuncParamList(list []*ast.Field) (string, error) { + var res strings.Builder + for pos, field := range list { + for namePos, name := range field.Names { + res.WriteString(name.Name) + if namePos+1 < len(field.Names) { + res.WriteString(", ") + } + } + if len(field.Names) > 0 { + res.WriteString(" ") + } + fieldType, err := typeFromAST(field.Type) + if err != nil { + return "", err + } + res.WriteString(fieldType.String()) + if pos+1 < len(list) { + res.WriteString(", ") + } + } + return res.String(), nil } // Pkg is a parsed build.Package. @@ -206,7 +370,7 @@ type Spec struct { } // typeSpec locates the *ast.TypeSpec for type id in the import path. -func typeSpec(path, id string, typeParams []string, srcDir string) (Pkg, Spec, error) { +func typeSpec(path string, typ Type, srcDir string) (Pkg, Spec, error) { var pkg *build.Package var err error @@ -239,34 +403,47 @@ func typeSpec(path, id string, typeParams []string, srcDir string) (Pkg, Spec, e } for _, spec := range decl.Specs { spec := spec.(*ast.TypeSpec) - if spec.Name.Name != id { + if spec.Name.Name != typ.ID { continue } - tParams := make(map[string]string, len(typeParams)) - if spec.TypeParams != nil { - var specParamNames []string - for _, typeParam := range spec.TypeParams.List { - for _, name := range typeParam.Names { - if name == nil { - continue - } - specParamNames = append(specParamNames, name.Name) - } - } - if len(specParamNames) != len(typeParams) { - continue - } - for pos, specParamName := range specParamNames { - tParams[specParamName] = typeParams[pos] - } + typeParams, ok := matchTypeParams(spec, typ.Params) + if !ok { + continue } p := Pkg{Package: pkg, FileSet: fset} - s := Spec{TypeSpec: spec, TypeParams: tParams} + s := Spec{TypeSpec: spec, TypeParams: typeParams} return p, s, nil } } } - return Pkg{}, Spec{}, fmt.Errorf("type %s not found in %s", id, path) + return Pkg{}, Spec{}, fmt.Errorf("type %s not found in %s", typ.ID, path) +} + +// matchTypeParams returns a map of type parameters from a parsed interface +// definition and the types that fill them from the user's specified type +// info. If the passed params can't be used to fill the type parameters on the +// passed type, a nil map and false are returned. No type checking is done, +// only that there are sufficient types to match. +func matchTypeParams(spec *ast.TypeSpec, params []string) (map[string]string, bool) { + res := make(map[string]string, len(params)) + if spec.TypeParams != nil { + var specParamNames []string + for _, typeParam := range spec.TypeParams.List { + for _, name := range typeParam.Names { + if name == nil { + continue + } + specParamNames = append(specParamNames, name.Name) + } + } + if len(specParamNames) != len(params) { + return nil, false + } + for pos, specParamName := range specParamNames { + res[specParamName] = params[pos] + } + } + return res, true } // gofmt pretty-prints e. @@ -302,13 +479,13 @@ func (p Pkg) fullType(e ast.Expr) string { return p.gofmt(e) } -func (p Pkg) params(field *ast.Field, genericTypes map[string]string) []Param { +func (p Pkg) params(field *ast.Field, typeParams map[string]string) []Param { var params []Param var typ string ident, ok := field.Type.(*ast.Ident) if !ok || ident == nil { typ = p.fullType(field.Type) - } else if genType, ok := genericTypes[ident.Name]; ok { + } else if genType, ok := typeParams[ident.Name]; ok { typ = genType } else { typ = p.fullType(field.Type) @@ -351,12 +528,12 @@ const ( WithoutComments EmitComments = false ) -func (p Pkg) funcsig(f *ast.Field, genericParams map[string]string, cmap ast.CommentMap, comments EmitComments) Func { +func (p Pkg) funcsig(f *ast.Field, typeParams map[string]string, cmap ast.CommentMap, comments EmitComments) Func { fn := Func{Name: f.Names[0].Name} typ := f.Type.(*ast.FuncType) if typ.Params != nil { for _, field := range typ.Params.List { - for _, param := range p.params(field, genericParams) { + for _, param := range p.params(field, typeParams) { // only for method parameters: // assign a blank identifier "_" to an anonymous parameter if param.Name == "" { @@ -368,7 +545,7 @@ func (p Pkg) funcsig(f *ast.Field, genericParams map[string]string, cmap ast.Com } if typ.Results != nil { for _, field := range typ.Results.List { - fn.Res = append(fn.Res, p.params(field, genericParams)...) + fn.Res = append(fn.Res, p.params(field, typeParams)...) } } if comments == WithComments && f.Doc != nil { @@ -393,13 +570,13 @@ func funcs(iface, srcDir, recvPkg string, comments EmitComments) ([]Func, error) } // Locate the interface. - path, id, typeParams, err := findInterface(iface, srcDir) + path, typ, err := findInterface(iface, srcDir) if err != nil { return nil, err } // Parse the package and find the interface declaration. - p, spec, err := typeSpec(path, id, typeParams, srcDir) + p, spec, err := typeSpec(path, typ, srcDir) if err != nil { return nil, fmt.Errorf("interface %s not found: %s", iface, err) } @@ -557,7 +734,7 @@ to prevent shell globbing. recvs := strings.Fields(recv) receiver := recvs[len(recvs)-1] // note that this correctly handles "s *Struct" and "*Struct" receiver = strings.TrimPrefix(receiver, "*") - pkg, _, err := typeSpec("", receiver, nil, *flagSrcDir) + pkg, _, err := typeSpec("", Type{ID: receiver}, *flagSrcDir) if err == nil { recvPkg = pkg.Package.Name } diff --git a/impl_test.go b/impl_test.go index b2302b1..2cbbf6a 100644 --- a/impl_test.go +++ b/impl_test.go @@ -20,45 +20,45 @@ func (b errBool) String() string { func TestFindInterface(t *testing.T) { t.Parallel() cases := []struct { - iface string - path string - id string - typeParams []string - wantErr bool + input string + path string + typ Type + wantErr bool }{ - {iface: "net.Conn", path: "net", id: "Conn"}, - {iface: "http.ResponseWriter", path: "net/http", id: "ResponseWriter"}, - {iface: "net.Tennis", wantErr: true}, - {iface: "a + b", wantErr: true}, - {iface: "a/b/c/", wantErr: true}, - {iface: "a/b/c/pkg", wantErr: true}, - {iface: "a/b/c/pkg.", wantErr: true}, - {iface: "a/b/c/pkg.Typ", path: "a/b/c/pkg", id: "Typ"}, - {iface: "gopkg.in/yaml.v2.Unmarshaler", path: "gopkg.in/yaml.v2", id: "Unmarshaler"}, - {iface: "github.com/josharian/impl/testdata.GenericInterface1[string]", path: "github.com/josharian/impl/testdata", id: "GenericInterface1", typeParams: []string{"string"}}, + {input: "net.Conn", path: "net", typ: Type{ID: "Conn"}}, + {input: "http.ResponseWriter", path: "net/http", typ: Type{ID: "ResponseWriter"}}, + {input: "net.Tennis", wantErr: true}, + {input: "a + b", wantErr: true}, + {input: "a/b/c/", wantErr: true}, + {input: "a/b/c/pkg", wantErr: true}, + {input: "a/b/c/pkg.", wantErr: true}, + {input: "a/b/c/pkg.Typ", path: "a/b/c/pkg", typ: Type{ID: "Typ"}}, + {input: "gopkg.in/yaml.v2.Unmarshaler", path: "gopkg.in/yaml.v2", typ: Type{ID: "Unmarshaler"}}, + {input: "github.com/josharian/impl/testdata.GenericInterface1[string]", path: "github.com/josharian/impl/testdata", typ: Type{ID: "GenericInterface1", Params: []string{"string"}}}, + {input: "github.com/josharian/impl/testdata.GenericInterface1[*string]", path: "github.com/josharian/impl/testdata", typ: Type{ID: "GenericInterface1", Params: []string{"*string"}}}, } for _, tt := range cases { tt := tt - t.Run(tt.iface, func(t *testing.T) { + t.Run(tt.input, func(t *testing.T) { t.Parallel() - path, id, typeParams, err := findInterface(tt.iface, ".") + path, typ, err := findInterface(tt.input, ".") gotErr := err != nil if tt.wantErr != gotErr { - t.Fatalf("findInterface(%q).err=%v want %s", tt.iface, err, errBool(tt.wantErr)) + t.Fatalf("findInterface(%q).err=%v want %s", tt.input, err, errBool(tt.wantErr)) } if tt.path != path { - t.Errorf("findInterface(%q).path=%q want %q", tt.iface, path, tt.path) + t.Errorf("findInterface(%q).path=%q want %q", tt.input, path, tt.path) } - if tt.id != id { - t.Errorf("findInterface(%q).id=%q want %q", tt.iface, id, tt.id) + if tt.typ.ID != typ.ID { + t.Errorf("findInterface(%q).id=%q want %q", tt.input, typ.ID, tt.typ.ID) } - if len(tt.typeParams) != len(typeParams) { - t.Errorf("findInterface(%q).len(typeParams)=%d want %d", tt.iface, len(typeParams), len(tt.typeParams)) + if len(tt.typ.Params) != len(typ.Params) { + t.Errorf("findInterface(%q).len(typeParams)=%d want %d", tt.input, len(typ.Params), len(tt.typ.Params)) } - for pos, v := range tt.typeParams { - if v != typeParams[pos] { - t.Errorf("findInterface(%q).typeParams[%d]=%q, want %q", tt.iface, pos, typeParams[pos], v) + for pos, v := range tt.typ.Params { + if v != typ.Params[pos] { + t.Errorf("findInterface(%q).typeParams[%d]=%q, want %q", tt.input, pos, typ.Params[pos], v) } } }) @@ -69,26 +69,26 @@ func TestTypeSpec(t *testing.T) { // For now, just test whether we can find the interface. cases := []struct { path string - id string + typ Type wantErr bool }{ - {path: "net", id: "Conn"}, - {path: "net", id: "Con", wantErr: true}, + {path: "net", typ: Type{ID: "Conn"}}, + {path: "net", typ: Type{ID: "Con"}, wantErr: true}, } for _, tt := range cases { - p, spec, err := typeSpec(tt.path, tt.id, nil, "") + p, spec, err := typeSpec(tt.path, tt.typ, "") gotErr := err != nil if tt.wantErr != gotErr { - t.Errorf("typeSpec(%q, %q).err=%v want %s", tt.path, tt.id, err, errBool(tt.wantErr)) + t.Errorf("typeSpec(%q, %q).err=%v want %s", tt.path, tt.typ, err, errBool(tt.wantErr)) continue } if err == nil { if reflect.DeepEqual(p, Pkg{}) { - t.Errorf("typeSpec(%q, %q).pkg=Pkg{} want non-nil", tt.path, tt.id) + t.Errorf("typeSpec(%q, %q).pkg=Pkg{} want non-nil", tt.path, tt.typ) } if reflect.DeepEqual(spec, Spec{}) { - t.Errorf("typeSpec(%q, %q).spec=Spec{} want non-nil", tt.path, tt.id) + t.Errorf("typeSpec(%q, %q).spec=Spec{} want non-nil", tt.path, tt.typ) } } } @@ -760,44 +760,68 @@ func TestParseTypeParams(t *testing.T) { t.Parallel() cases := []struct { - desc string - input string - wantID string - wantParams []string - wantErr bool + desc string + input string + want Type + wantErr bool }{ - {desc: "non-generic type", input: "Reader", wantID: "Reader"}, - {desc: "one type param", input: "Reader[Foo]", wantID: "Reader", wantParams: []string{"Foo"}}, - {desc: "two type params", input: "Reader[Foo, Bar]", wantID: "Reader", wantParams: []string{"Foo", "Bar"}}, - {desc: "three type params", input: "Reader[Foo, Bar, Baz]", wantID: "Reader", wantParams: []string{"Foo", "Bar", "Baz"}}, - {desc: "no spaces", input: "Reader[Foo,Bar]", wantID: "Reader", wantParams: []string{"Foo", "Bar"}}, + {desc: "non-generic type", input: "Reader", want: Type{ID: "Reader"}}, + {desc: "one type param", input: "Reader[Foo]", want: Type{ID: "Reader", Params: []string{"Foo"}}}, + {desc: "two type params", input: "Reader[Foo, Bar]", want: Type{ID: "Reader", Params: []string{"Foo", "Bar"}}}, + {desc: "three type params", input: "Reader[Foo, Bar, Baz]", want: Type{ID: "Reader", Params: []string{"Foo", "Bar", "Baz"}}}, + {desc: "no spaces", input: "Reader[Foo,Bar]", want: Type{ID: "Reader", Params: []string{"Foo", "Bar"}}}, {desc: "unclosed brackets", input: "Reader[Foo", wantErr: true}, {desc: "no params", input: "Reader[]", wantErr: true}, {desc: "space-only params", input: "Reader[ ]", wantErr: true}, {desc: "multiple space-only params", input: "Reader[ , , ]", wantErr: true}, {desc: "characters after bracket", input: "Reader[Foo]Bar", wantErr: true}, + {desc: "qualified generic type", input: "io.Reader[Foo]", want: Type{ID: "Reader", Params: []string{"Foo"}}}, + {desc: "qualified generic param", input: "Reader[io.Reader]", want: Type{ID: "Reader", Params: []string{"Reader"}}}, + {desc: "pointer qualified generic param", input: "Reader[*io.Reader]", want: Type{ID: "Reader", Params: []string{"*Reader"}}}, + {desc: "map generic param", input: "Reader[map[string]string]", want: Type{ID: "Reader", Params: []string{"map[string]string"}}}, + {desc: "pointer map generic param", input: "Reader[*map[string]string]", want: Type{ID: "Reader", Params: []string{"*map[string]string"}}}, + {desc: "pointer key map generic param", input: "Reader[map[*string]string]", want: Type{ID: "Reader", Params: []string{"map[*string]string"}}}, + {desc: "pointer value map generic param", input: "Reader[map[string]*string]", want: Type{ID: "Reader", Params: []string{"map[string]*string"}}}, + {desc: "slice generic param", input: "Reader[[]string]", want: Type{ID: "Reader", Params: []string{"[]string"}}}, + {desc: "pointer slice generic param", input: "Reader[*[]string]", want: Type{ID: "Reader", Params: []string{"*[]string"}}}, + {desc: "pointer slice value generic param", input: "Reader[[]*string]", want: Type{ID: "Reader", Params: []string{"[]*string"}}}, + {desc: "array generic param", input: "Reader[[1]string]", want: Type{ID: "Reader", Params: []string{"[1]string"}}}, + {desc: "pointer array generic param", input: "Reader[*[1]string]", want: Type{ID: "Reader", Params: []string{"*[1]string"}}}, + {desc: "pointer array value generic param", input: "Reader[[1]*string]", want: Type{ID: "Reader", Params: []string{"[1]*string"}}}, + {desc: "chan generic param", input: "Reader[chan error]", want: Type{ID: "Reader", Params: []string{"chan error"}}}, + {desc: "receiver chan generic param", input: "Reader[<-chan error]", want: Type{ID: "Reader", Params: []string{"<-chan error"}}}, + {desc: "send chan generic param", input: "Reader[chan<- error]", want: Type{ID: "Reader", Params: []string{"chan<- error"}}}, + {desc: "pointer chan generic param", input: "Reader[*chan error]", want: Type{ID: "Reader", Params: []string{"*chan error"}}}, + {desc: "func generic param", input: "Reader[func() string]", want: Type{ID: "Reader", Params: []string{"func() string"}}}, + {desc: "one arg func generic param", input: "Reader[func(a int) string]", want: Type{ID: "Reader", Params: []string{"func(a int) string"}}}, + {desc: "two arg one type func generic param", input: "Reader[func(a, b int) string]", want: Type{ID: "Reader", Params: []string{"func(a, b int) string"}}}, + {desc: "three arg one type func generic param", input: "Reader[func(a, b, c int) string]", want: Type{ID: "Reader", Params: []string{"func(a, b, c int) string"}}}, + {desc: "three arg two types func generic param", input: "Reader[func(a, b string, c int) string]", want: Type{ID: "Reader", Params: []string{"func(a, b string, c int) string"}}}, + {desc: "three arg three types func generic param", input: "Reader[func(a bool, b string, c int) string]", want: Type{ID: "Reader", Params: []string{"func(a bool, b string, c int) string"}}}, + // don't need support for generics on the function type itself; function types must have no type parameters + // https://cs.opensource.google/go/go/+/master:src/go/parser/parser.go;l=1048;drc=cafb49ac731f862f386862d64b27b8314eeb2909 } for _, tt := range cases { tt := tt t.Run(tt.desc, func(t *testing.T) { t.Parallel() - id, params, err := parseTypeParams(tt.input) + typ, err := parseType(tt.input) if err != nil { if tt.wantErr { return } t.Fatalf("unexpected error: %s", err) } - if id != tt.wantID { - t.Errorf("wanted ID %q, got %q", tt.wantID, id) + if typ.ID != tt.want.ID { + t.Errorf("wanted ID %q, got %q", tt.want.ID, typ.ID) } - if len(params) != len(tt.wantParams) { - t.Errorf("wanted %d params, got %d: %v", len(tt.wantParams), len(params), params) + if len(typ.Params) != len(tt.want.Params) { + t.Errorf("wanted %d params, got %d: %v", len(tt.want.Params), len(typ.Params), typ.Params) } - for pos, param := range params { - if param != tt.wantParams[pos] { - t.Errorf("expected param %d to be %q, got %q: %v", pos, tt.wantParams[pos], param, params) + for pos, param := range typ.Params { + if param != tt.want.Params[pos] { + t.Errorf("expected param %d to be %q, got %q: %v", pos, tt.want.Params[pos], param, typ.Params) } } }) From 2d3f13f32b64ba08082efc1ca9203f08a7296270 Mon Sep 17 00:00:00 2001 From: Paddy Carver Date: Sun, 1 Jan 2023 23:35:17 -0800 Subject: [PATCH 07/20] Fix outdated comment. --- impl.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/impl.go b/impl.go index f8e0728..6f5c3b2 100644 --- a/impl.go +++ b/impl.go @@ -47,7 +47,7 @@ type Type struct { } // String constructs a reference to the Type. For example: -// Type{ID: "Foo", Params{{ID: "Bar"}, {ID: "Baz", Params: {{ID: "[]Quux"}}}} +// Type{ID: "Foo", Params{"Bar", "Baz[[]Quux]"}} // would yield // Foo[Bar, Baz[[]Quux]] func (t Type) String() string { From e5ca4b6cb7c44a344093ce7e022ed891b9410ad4 Mon Sep 17 00:00:00 2001 From: Paddy Carver Date: Sun, 1 Jan 2023 23:38:41 -0800 Subject: [PATCH 08/20] Update oudated comment. --- impl.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/impl.go b/impl.go index 6f5c3b2..9db7187 100644 --- a/impl.go +++ b/impl.go @@ -72,19 +72,19 @@ func parseType(in string) (Type, error) { return typeFromAST(expr) } -// findInterface returns the import path and identifier of an interface. +// findInterface returns the import path and type of an interface. // For example, given "http.ResponseWriter", findInterface returns -// "net/http", "ResponseWriter". +// "net/http", Type{ID: "ResponseWriter"}. // If a fully qualified interface is given, such as "net/http.ResponseWriter", // it simply parses the input. // If an unqualified interface such as "UserDefinedInterface" is given, then // the interface definition is presumed to be in the package within srcDir and -// findInterface returns "", "UserDefinedInterface". +// findInterface returns "", Type{ID: "UserDefinedInterface"}. // -// The typeParams return value will be populated for generic types. For example, -// given "foo[Bar, Baz]", the id return value will be "foo", and typeParams will -// be []string{"Bar", "Baz"}. The types of the type parameters should not be -// included; "foo[Bar any, Baz io.Reader]" is invalid. +// Generic types will have their type params returned in the Params property of +// the Type. Input should always reference generic types with their parameters +// filled, i.e. GenericType[string, bool], as opposed to +// GenericType[A any, B comparable]. func findInterface(input string, srcDir string) (path string, iface Type, err error) { if len(strings.Fields(input)) != 1 && !strings.Contains(input, "[") { return "", Type{}, fmt.Errorf("couldn't parse interface: %s", input) From e9c31bea9b7dea30191d43e1c1df10a76158c3db Mon Sep 17 00:00:00 2001 From: Paddy Carver Date: Sun, 1 Jan 2023 23:47:10 -0800 Subject: [PATCH 09/20] Switch lingering if/elseif chain to a type switch. --- impl.go | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/impl.go b/impl.go index 9db7187..6acfb95 100644 --- a/impl.go +++ b/impl.go @@ -482,12 +482,14 @@ func (p Pkg) fullType(e ast.Expr) string { func (p Pkg) params(field *ast.Field, typeParams map[string]string) []Param { var params []Param var typ string - ident, ok := field.Type.(*ast.Ident) - if !ok || ident == nil { - typ = p.fullType(field.Type) - } else if genType, ok := typeParams[ident.Name]; ok { - typ = genType - } else { + switch expr := field.Type.(type) { + case *ast.Ident: + if genType, ok := typeParams[expr.Name]; ok { + typ = genType + } else { + typ = p.fullType(field.Type) + } + default: typ = p.fullType(field.Type) } for _, name := range field.Names { @@ -495,7 +497,7 @@ func (p Pkg) params(field *ast.Field, typeParams map[string]string) []Param { } // Handle anonymous params if len(params) == 0 { - params = []Param{Param{Type: typ}} + params = []Param{{Type: typ}} } return params } From 9f0a8ce703853b314e5f20b3bb3929eb03c566ac Mon Sep 17 00:00:00 2001 From: Paddy Carver Date: Mon, 2 Jan 2023 10:54:53 -0800 Subject: [PATCH 10/20] Add test for panic. Add a test for the interface assertion that was failing. --- impl_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/impl_test.go b/impl_test.go index 2cbbf6a..062ffc7 100644 --- a/impl_test.go +++ b/impl_test.go @@ -776,6 +776,7 @@ func TestParseTypeParams(t *testing.T) { {desc: "multiple space-only params", input: "Reader[ , , ]", wantErr: true}, {desc: "characters after bracket", input: "Reader[Foo]Bar", wantErr: true}, {desc: "qualified generic type", input: "io.Reader[Foo]", want: Type{ID: "Reader", Params: []string{"Foo"}}}, + {desc: "qualified generic type with two params", input: "io.Reader[Foo, Bar]", want: Type{ID: "Reader", Params: []string{"Foo", "Bar"}}}, {desc: "qualified generic param", input: "Reader[io.Reader]", want: Type{ID: "Reader", Params: []string{"Reader"}}}, {desc: "pointer qualified generic param", input: "Reader[*io.Reader]", want: Type{ID: "Reader", Params: []string{"*Reader"}}}, {desc: "map generic param", input: "Reader[map[string]string]", want: Type{ID: "Reader", Params: []string{"map[string]string"}}}, From 50818c69297061d4ea3e08236c532a6d6cb9c440 Mon Sep 17 00:00:00 2001 From: Paddy Carver Date: Mon, 2 Jan 2023 10:57:32 -0800 Subject: [PATCH 11/20] Fix panic. Had a lingering type assertion that should have used the ID from the parsed AST type instead. Also noticed that we were losing the package names on the type parameters, which would lead to potentially-subtle incorrect results; we should always be using the package if the user submits it. This requires an update to typeFromAST to preserve the package, and then updating findInterface to strip the package off when dealing with qualified interfaces. --- impl.go | 13 +++++++++---- impl_test.go | 9 +++++---- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/impl.go b/impl.go index 6acfb95..b487718 100644 --- a/impl.go +++ b/impl.go @@ -144,8 +144,8 @@ func findInterface(input string, srcDir string) (path string, iface Type, err er // package hack // // var i Reader - decl := f.Decls[0].(*ast.GenDecl) // var i io.Reader - spec := decl.Specs[0].(*ast.ValueSpec) // i io.Reader + decl := f.Decls[0].(*ast.GenDecl) // var i Reader + spec := decl.Specs[0].(*ast.ValueSpec) // i Reader iface, err = typeFromAST(spec.Type) return path, iface, err } @@ -167,6 +167,11 @@ func findInterface(input string, srcDir string) (path string, iface Type, err er decl := f.Decls[1].(*ast.GenDecl) // var i io.Reader spec := decl.Specs[0].(*ast.ValueSpec) // i io.Reader iface, err = typeFromAST(spec.Type) + if err != nil { + return path, iface, fmt.Errorf("error parsing type from AST: %w", err) + } + // trim off the package which got smooshed on when resolving the type + _, iface.ID, _ = strings.Cut(iface.ID, ".") return path, iface, err } @@ -178,7 +183,7 @@ func typeFromAST(in ast.Expr) (Type, error) { case *ast.SelectorExpr: // an identifier in a different package (io.Reader) shows up as a SelectorExpr // we need to pull the name out - return Type{ID: specType.Sel.Name}, nil + return Type{ID: specType.X.(*ast.Ident).Name + "." + specType.Sel.Name}, nil case *ast.StarExpr: // pointer identifiers (*Reader) show up as a StarExpr // we need to pull the name out and prefix it with a * @@ -313,7 +318,7 @@ func typeFromAST(in ast.Expr) (Type, error) { return Type{}, fmt.Errorf("got type parameters for a type ID, which is very confusing: %s", id.String()) } res := Type{ - ID: specType.X.(*ast.Ident).Name, + ID: id.ID, } for _, typeParam := range specType.Indices { param, err := typeFromAST(typeParam) diff --git a/impl_test.go b/impl_test.go index 062ffc7..4e809fa 100644 --- a/impl_test.go +++ b/impl_test.go @@ -775,10 +775,11 @@ func TestParseTypeParams(t *testing.T) { {desc: "space-only params", input: "Reader[ ]", wantErr: true}, {desc: "multiple space-only params", input: "Reader[ , , ]", wantErr: true}, {desc: "characters after bracket", input: "Reader[Foo]Bar", wantErr: true}, - {desc: "qualified generic type", input: "io.Reader[Foo]", want: Type{ID: "Reader", Params: []string{"Foo"}}}, - {desc: "qualified generic type with two params", input: "io.Reader[Foo, Bar]", want: Type{ID: "Reader", Params: []string{"Foo", "Bar"}}}, - {desc: "qualified generic param", input: "Reader[io.Reader]", want: Type{ID: "Reader", Params: []string{"Reader"}}}, - {desc: "pointer qualified generic param", input: "Reader[*io.Reader]", want: Type{ID: "Reader", Params: []string{"*Reader"}}}, + {desc: "qualified generic type", input: "io.Reader[Foo]", want: Type{ID: "io.Reader", Params: []string{"Foo"}}}, + {desc: "qualified generic type with two params", input: "io.Reader[Foo, Bar]", want: Type{ID: "io.Reader", Params: []string{"Foo", "Bar"}}}, + {desc: "qualified generic param", input: "Reader[io.Reader]", want: Type{ID: "Reader", Params: []string{"io.Reader"}}}, + {desc: "qualified and unqualified generic param", input: "Reader[io.Reader, string]", want: Type{ID: "Reader", Params: []string{"io.Reader", "string"}}}, + {desc: "pointer qualified generic param", input: "Reader[*io.Reader]", want: Type{ID: "Reader", Params: []string{"*io.Reader"}}}, {desc: "map generic param", input: "Reader[map[string]string]", want: Type{ID: "Reader", Params: []string{"map[string]string"}}}, {desc: "pointer map generic param", input: "Reader[*map[string]string]", want: Type{ID: "Reader", Params: []string{"*map[string]string"}}}, {desc: "pointer key map generic param", input: "Reader[map[*string]string]", want: Type{ID: "Reader", Params: []string{"map[*string]string"}}}, From a52407f7c38ea8f03a82e09a6f9e0ec410d0aaee Mon Sep 17 00:00:00 2001 From: Josh Bleecher Snyder Date: Tue, 16 Jan 2024 09:55:00 -0800 Subject: [PATCH 12/20] go.mod: update all deps --- go.mod | 6 +++--- go.sum | 6 ++++++ 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index ec0b262..faddeb2 100644 --- a/go.mod +++ b/go.mod @@ -2,9 +2,9 @@ module github.com/josharian/impl go 1.18 -require golang.org/x/tools v0.4.0 +require golang.org/x/tools v0.17.0 require ( - golang.org/x/mod v0.7.0 // indirect - golang.org/x/sys v0.3.0 // indirect + golang.org/x/mod v0.14.0 // indirect + golang.org/x/sys v0.16.0 // indirect ) diff --git a/go.sum b/go.sum index 372a637..bfbc802 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,12 @@ golang.org/x/mod v0.7.0 h1:LapD9S96VoQRhi/GrNTqeBJFrUjs5UHCAtTlgwA5oZA= golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.14.0 h1:dGoOF9QVLYng8IHTm7BAyWqCqSheQ5pYWGhzW00YJr0= +golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ= golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= +golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/tools v0.4.0 h1:7mTAgkunk3fr4GAloyyCasadO6h9zSsQZbwvcaIciV4= golang.org/x/tools v0.4.0/go.mod h1:UE5sM2OK9E/d67R0ANs2xJizIymRP5gJU295PvKXxjQ= +golang.org/x/tools v0.17.0 h1:FvmRgNOcs3kOa+T20R1uhfP9F6HgG2mfxDv1vrx1Htc= +golang.org/x/tools v0.17.0/go.mod h1:xsh6VxdV005rRVaS6SSAf9oiAqljS7UZUacMZ8Bnsps= From 902a46e173938cceaef8f43391134f9f4beb5232 Mon Sep 17 00:00:00 2001 From: Josh Bleecher Snyder Date: Tue, 16 Jan 2024 09:57:22 -0800 Subject: [PATCH 13/20] rename type.ID to type.Name and let gofumpt scribble on the files --- impl.go | 36 +++++------ impl_test.go | 180 ++++++++++++++++++++++++++------------------------- 2 files changed, 109 insertions(+), 107 deletions(-) diff --git a/impl.go b/impl.go index b487718..8ce5d18 100644 --- a/impl.go +++ b/impl.go @@ -28,9 +28,9 @@ var ( // Type is a parsed type reference. type Type struct { - // ID is the type's ID or name. For example, in "foo[Bar, Baz]", the ID + // Name is the type's name. For example, in "foo[Bar, Baz]", the name // is "foo". - ID string + Name string // Params are the type's type params. For example, in "foo[Bar, Baz]", // the Params are []string{"Bar", "Baz"}. @@ -52,7 +52,7 @@ type Type struct { // Foo[Bar, Baz[[]Quux]] func (t Type) String() string { var res strings.Builder - res.WriteString(t.ID) + res.WriteString(t.Name) if len(t.Params) < 1 { return res.String() } @@ -171,7 +171,7 @@ func findInterface(input string, srcDir string) (path string, iface Type, err er return path, iface, fmt.Errorf("error parsing type from AST: %w", err) } // trim off the package which got smooshed on when resolving the type - _, iface.ID, _ = strings.Cut(iface.ID, ".") + _, iface.Name, _ = strings.Cut(iface.Name, ".") return path, iface, err } @@ -179,11 +179,11 @@ func typeFromAST(in ast.Expr) (Type, error) { switch specType := in.(type) { case *ast.Ident: // a standalone identifier (Reader) shows up as an Ident - return Type{ID: specType.Name}, nil + return Type{Name: specType.Name}, nil case *ast.SelectorExpr: // an identifier in a different package (io.Reader) shows up as a SelectorExpr // we need to pull the name out - return Type{ID: specType.X.(*ast.Ident).Name + "." + specType.Sel.Name}, nil + return Type{Name: specType.X.(*ast.Ident).Name + "." + specType.Sel.Name}, nil case *ast.StarExpr: // pointer identifiers (*Reader) show up as a StarExpr // we need to pull the name out and prefix it with a * @@ -191,7 +191,7 @@ func typeFromAST(in ast.Expr) (Type, error) { if err != nil { return Type{}, err } - typ.ID = "*" + typ.ID + typ.Name = "*" + typ.Name return typ, nil case *ast.ArrayType: // slices and arrays ([]Reader) show up as an ArrayType @@ -204,7 +204,7 @@ func typeFromAST(in ast.Expr) (Type, error) { prefix += specType.Len.(*ast.BasicLit).Value } prefix += "]" - typ.ID = prefix + typ.ID + typ.Name = prefix + typ.Name return typ, nil case *ast.MapType: // maps (map[string]Reader) show up as a MapType @@ -217,7 +217,7 @@ func typeFromAST(in ast.Expr) (Type, error) { return Type{}, err } return Type{ - ID: "map[" + key.String() + "]" + value.String(), + Name: "map[" + key.String() + "]" + value.String(), }, nil case *ast.FuncType: // funcs (func() Reader) show up as a FuncType @@ -270,7 +270,7 @@ func typeFromAST(in ast.Expr) (Type, error) { res.WriteString(")") } } - return Type{ID: res.String()}, nil + return Type{Name: res.String()}, nil case *ast.ChanType: var res strings.Builder // channels (chan Reader) show up as a ChanType @@ -289,7 +289,7 @@ func typeFromAST(in ast.Expr) (Type, error) { if err != nil { return Type{}, err } - valType.ID = res.String() + valType.ID + valType.Name = res.String() + valType.Name return valType, nil case *ast.IndexExpr: // a generic type with one type parameter (Reader[Foo]) shows up as an IndexExpr @@ -305,7 +305,7 @@ func typeFromAST(in ast.Expr) (Type, error) { return Type{}, err } return Type{ - ID: id.ID, + Name: id.Name, Params: []string{param.String()}, }, nil case *ast.IndexListExpr: @@ -318,7 +318,7 @@ func typeFromAST(in ast.Expr) (Type, error) { return Type{}, fmt.Errorf("got type parameters for a type ID, which is very confusing: %s", id.String()) } res := Type{ - ID: id.ID, + Name: id.Name, } for _, typeParam := range specType.Indices { param, err := typeFromAST(typeParam) @@ -408,7 +408,7 @@ func typeSpec(path string, typ Type, srcDir string) (Pkg, Spec, error) { } for _, spec := range decl.Specs { spec := spec.(*ast.TypeSpec) - if spec.Name.Name != typ.ID { + if spec.Name.Name != typ.Name { continue } typeParams, ok := matchTypeParams(spec, typ.Params) @@ -421,7 +421,7 @@ func typeSpec(path string, typ Type, srcDir string) (Pkg, Spec, error) { } } } - return Pkg{}, Spec{}, fmt.Errorf("type %s not found in %s", typ.ID, path) + return Pkg{}, Spec{}, fmt.Errorf("type %s not found in %s", typ.Name, path) } // matchTypeParams returns a map of type parameters from a parsed interface @@ -708,7 +708,7 @@ impl [-dir directory] fmt.Fprint(os.Stderr, ` Examples: - + impl 'f *File' io.Reader impl Murmur hash.Hash impl -dir $GOPATH/src/github.com/josharian/impl Murmur hash.Hash @@ -735,13 +735,13 @@ to prevent shell globbing. } } - var recvPkg = *flagRecvPkg + recvPkg := *flagRecvPkg if recvPkg == "" { // " s *Struct " , receiver: Struct recvs := strings.Fields(recv) receiver := recvs[len(recvs)-1] // note that this correctly handles "s *Struct" and "*Struct" receiver = strings.TrimPrefix(receiver, "*") - pkg, _, err := typeSpec("", Type{ID: receiver}, *flagSrcDir) + pkg, _, err := typeSpec("", Type{Name: receiver}, *flagSrcDir) if err == nil { recvPkg = pkg.Package.Name } diff --git a/impl_test.go b/impl_test.go index 4e809fa..412e982 100644 --- a/impl_test.go +++ b/impl_test.go @@ -25,17 +25,17 @@ func TestFindInterface(t *testing.T) { typ Type wantErr bool }{ - {input: "net.Conn", path: "net", typ: Type{ID: "Conn"}}, - {input: "http.ResponseWriter", path: "net/http", typ: Type{ID: "ResponseWriter"}}, + {input: "net.Conn", path: "net", typ: Type{Name: "Conn"}}, + {input: "http.ResponseWriter", path: "net/http", typ: Type{Name: "ResponseWriter"}}, {input: "net.Tennis", wantErr: true}, {input: "a + b", wantErr: true}, {input: "a/b/c/", wantErr: true}, {input: "a/b/c/pkg", wantErr: true}, {input: "a/b/c/pkg.", wantErr: true}, - {input: "a/b/c/pkg.Typ", path: "a/b/c/pkg", typ: Type{ID: "Typ"}}, - {input: "gopkg.in/yaml.v2.Unmarshaler", path: "gopkg.in/yaml.v2", typ: Type{ID: "Unmarshaler"}}, - {input: "github.com/josharian/impl/testdata.GenericInterface1[string]", path: "github.com/josharian/impl/testdata", typ: Type{ID: "GenericInterface1", Params: []string{"string"}}}, - {input: "github.com/josharian/impl/testdata.GenericInterface1[*string]", path: "github.com/josharian/impl/testdata", typ: Type{ID: "GenericInterface1", Params: []string{"*string"}}}, + {input: "a/b/c/pkg.Typ", path: "a/b/c/pkg", typ: Type{Name: "Typ"}}, + {input: "gopkg.in/yaml.v2.Unmarshaler", path: "gopkg.in/yaml.v2", typ: Type{Name: "Unmarshaler"}}, + {input: "github.com/josharian/impl/testdata.GenericInterface1[string]", path: "github.com/josharian/impl/testdata", typ: Type{Name: "GenericInterface1", Params: []string{"string"}}}, + {input: "github.com/josharian/impl/testdata.GenericInterface1[*string]", path: "github.com/josharian/impl/testdata", typ: Type{Name: "GenericInterface1", Params: []string{"*string"}}}, } for _, tt := range cases { @@ -50,8 +50,8 @@ func TestFindInterface(t *testing.T) { if tt.path != path { t.Errorf("findInterface(%q).path=%q want %q", tt.input, path, tt.path) } - if tt.typ.ID != typ.ID { - t.Errorf("findInterface(%q).id=%q want %q", tt.input, typ.ID, tt.typ.ID) + if tt.typ.Name != typ.Name { + t.Errorf("findInterface(%q).id=%q want %q", tt.input, typ.Name, tt.typ.Name) } if len(tt.typ.Params) != len(typ.Params) { t.Errorf("findInterface(%q).len(typeParams)=%d want %d", tt.input, len(typ.Params), len(tt.typ.Params)) @@ -72,8 +72,8 @@ func TestTypeSpec(t *testing.T) { typ Type wantErr bool }{ - {path: "net", typ: Type{ID: "Conn"}}, - {path: "net", typ: Type{ID: "Con"}, wantErr: true}, + {path: "net", typ: Type{Name: "Conn"}}, + {path: "net", typ: Type{Name: "Con"}, wantErr: true}, } for _, tt := range cases { @@ -340,69 +340,70 @@ func TestValidMethodComments(t *testing.T) { { iface: "github.com/josharian/impl/testdata.Interface1", want: []Func{ - Func{ + { Name: "Method1", Params: []Param{ - Param{ + { Name: "arg1", Type: "string", - }, Param{ + }, { Name: "arg2", Type: "string", - }}, + }, + }, Res: []Param{ - Param{ + { Name: "result", Type: "string", }, - Param{ + { Name: "err", Type: "error", }, }, Comments: "// Method1 is the first method of Interface1.\n", }, - Func{ + { Name: "Method2", Params: []Param{ - Param{ + { Name: "arg1", Type: "int", }, - Param{ + { Name: "arg2", Type: "int", }, }, Res: []Param{ - Param{ + { Name: "result", Type: "int", }, - Param{ + { Name: "err", Type: "error", }, }, Comments: "// Method2 is the second method of Interface1.\n", }, - Func{ + { Name: "Method3", Params: []Param{ - Param{ + { Name: "arg1", Type: "bool", }, - Param{ + { Name: "arg2", Type: "bool", }, }, Res: []Param{ - Param{ + { Name: "result", Type: "bool", }, - Param{ + { Name: "err", Type: "error", }, @@ -414,72 +415,72 @@ func TestValidMethodComments(t *testing.T) { { iface: "github.com/josharian/impl/testdata.Interface2", want: []Func{ - Func{ + { Name: "Method1", Params: []Param{ - Param{ + { Name: "arg1", Type: "int64", }, - Param{ + { Name: "arg2", Type: "int64", }, }, Res: []Param{ - Param{ + { Name: "result", Type: "int64", }, - Param{ + { Name: "err", Type: "error", }, }, Comments: "/*\n\t\tMethod1 is the first method of Interface2.\n\t*/\n", }, - Func{ + { Name: "Method2", Params: []Param{ - Param{ + { Name: "arg1", Type: "float64", }, - Param{ + { Name: "arg2", Type: "float64", }, }, Res: []Param{ - Param{ + { Name: "result", Type: "float64", }, - Param{ + { Name: "err", Type: "error", }, }, Comments: "/*\n\t\tMethod2 is the second method of Interface2.\n\t*/\n", }, - Func{ + { Name: "Method3", Params: []Param{ - Param{ + { Name: "arg1", Type: "interface{}", }, - Param{ + { Name: "arg2", Type: "interface{}", }, }, Res: []Param{ - Param{ + { Name: "result", Type: "interface{}", }, - Param{ + { Name: "err", Type: "error", }, @@ -491,69 +492,70 @@ func TestValidMethodComments(t *testing.T) { { iface: "github.com/josharian/impl/testdata.Interface3", want: []Func{ - Func{ + { Name: "Method1", Params: []Param{ - Param{ + { Name: "_", Type: "string", - }, Param{ + }, { Name: "_", Type: "string", - }}, + }, + }, Res: []Param{ - Param{ + { Name: "", Type: "string", }, - Param{ + { Name: "", Type: "error", }, }, Comments: "// Method1 is the first method of Interface3.\n", }, - Func{ + { Name: "Method2", Params: []Param{ - Param{ + { Name: "_", Type: "int", }, - Param{ + { Name: "arg2", Type: "int", }, }, Res: []Param{ - Param{ + { Name: "_", Type: "int", }, - Param{ + { Name: "err", Type: "error", }, }, Comments: "// Method2 is the second method of Interface3.\n", }, - Func{ + { Name: "Method3", Params: []Param{ - Param{ + { Name: "arg1", Type: "bool", }, - Param{ + { Name: "arg2", Type: "bool", }, }, Res: []Param{ - Param{ + { Name: "result1", Type: "bool", }, - Param{ + { Name: "result2", Type: "bool", }, @@ -765,41 +767,41 @@ func TestParseTypeParams(t *testing.T) { want Type wantErr bool }{ - {desc: "non-generic type", input: "Reader", want: Type{ID: "Reader"}}, - {desc: "one type param", input: "Reader[Foo]", want: Type{ID: "Reader", Params: []string{"Foo"}}}, - {desc: "two type params", input: "Reader[Foo, Bar]", want: Type{ID: "Reader", Params: []string{"Foo", "Bar"}}}, - {desc: "three type params", input: "Reader[Foo, Bar, Baz]", want: Type{ID: "Reader", Params: []string{"Foo", "Bar", "Baz"}}}, - {desc: "no spaces", input: "Reader[Foo,Bar]", want: Type{ID: "Reader", Params: []string{"Foo", "Bar"}}}, + {desc: "non-generic type", input: "Reader", want: Type{Name: "Reader"}}, + {desc: "one type param", input: "Reader[Foo]", want: Type{Name: "Reader", Params: []string{"Foo"}}}, + {desc: "two type params", input: "Reader[Foo, Bar]", want: Type{Name: "Reader", Params: []string{"Foo", "Bar"}}}, + {desc: "three type params", input: "Reader[Foo, Bar, Baz]", want: Type{Name: "Reader", Params: []string{"Foo", "Bar", "Baz"}}}, + {desc: "no spaces", input: "Reader[Foo,Bar]", want: Type{Name: "Reader", Params: []string{"Foo", "Bar"}}}, {desc: "unclosed brackets", input: "Reader[Foo", wantErr: true}, {desc: "no params", input: "Reader[]", wantErr: true}, {desc: "space-only params", input: "Reader[ ]", wantErr: true}, {desc: "multiple space-only params", input: "Reader[ , , ]", wantErr: true}, {desc: "characters after bracket", input: "Reader[Foo]Bar", wantErr: true}, - {desc: "qualified generic type", input: "io.Reader[Foo]", want: Type{ID: "io.Reader", Params: []string{"Foo"}}}, - {desc: "qualified generic type with two params", input: "io.Reader[Foo, Bar]", want: Type{ID: "io.Reader", Params: []string{"Foo", "Bar"}}}, - {desc: "qualified generic param", input: "Reader[io.Reader]", want: Type{ID: "Reader", Params: []string{"io.Reader"}}}, - {desc: "qualified and unqualified generic param", input: "Reader[io.Reader, string]", want: Type{ID: "Reader", Params: []string{"io.Reader", "string"}}}, - {desc: "pointer qualified generic param", input: "Reader[*io.Reader]", want: Type{ID: "Reader", Params: []string{"*io.Reader"}}}, - {desc: "map generic param", input: "Reader[map[string]string]", want: Type{ID: "Reader", Params: []string{"map[string]string"}}}, - {desc: "pointer map generic param", input: "Reader[*map[string]string]", want: Type{ID: "Reader", Params: []string{"*map[string]string"}}}, - {desc: "pointer key map generic param", input: "Reader[map[*string]string]", want: Type{ID: "Reader", Params: []string{"map[*string]string"}}}, - {desc: "pointer value map generic param", input: "Reader[map[string]*string]", want: Type{ID: "Reader", Params: []string{"map[string]*string"}}}, - {desc: "slice generic param", input: "Reader[[]string]", want: Type{ID: "Reader", Params: []string{"[]string"}}}, - {desc: "pointer slice generic param", input: "Reader[*[]string]", want: Type{ID: "Reader", Params: []string{"*[]string"}}}, - {desc: "pointer slice value generic param", input: "Reader[[]*string]", want: Type{ID: "Reader", Params: []string{"[]*string"}}}, - {desc: "array generic param", input: "Reader[[1]string]", want: Type{ID: "Reader", Params: []string{"[1]string"}}}, - {desc: "pointer array generic param", input: "Reader[*[1]string]", want: Type{ID: "Reader", Params: []string{"*[1]string"}}}, - {desc: "pointer array value generic param", input: "Reader[[1]*string]", want: Type{ID: "Reader", Params: []string{"[1]*string"}}}, - {desc: "chan generic param", input: "Reader[chan error]", want: Type{ID: "Reader", Params: []string{"chan error"}}}, - {desc: "receiver chan generic param", input: "Reader[<-chan error]", want: Type{ID: "Reader", Params: []string{"<-chan error"}}}, - {desc: "send chan generic param", input: "Reader[chan<- error]", want: Type{ID: "Reader", Params: []string{"chan<- error"}}}, - {desc: "pointer chan generic param", input: "Reader[*chan error]", want: Type{ID: "Reader", Params: []string{"*chan error"}}}, - {desc: "func generic param", input: "Reader[func() string]", want: Type{ID: "Reader", Params: []string{"func() string"}}}, - {desc: "one arg func generic param", input: "Reader[func(a int) string]", want: Type{ID: "Reader", Params: []string{"func(a int) string"}}}, - {desc: "two arg one type func generic param", input: "Reader[func(a, b int) string]", want: Type{ID: "Reader", Params: []string{"func(a, b int) string"}}}, - {desc: "three arg one type func generic param", input: "Reader[func(a, b, c int) string]", want: Type{ID: "Reader", Params: []string{"func(a, b, c int) string"}}}, - {desc: "three arg two types func generic param", input: "Reader[func(a, b string, c int) string]", want: Type{ID: "Reader", Params: []string{"func(a, b string, c int) string"}}}, - {desc: "three arg three types func generic param", input: "Reader[func(a bool, b string, c int) string]", want: Type{ID: "Reader", Params: []string{"func(a bool, b string, c int) string"}}}, + {desc: "qualified generic type", input: "io.Reader[Foo]", want: Type{Name: "io.Reader", Params: []string{"Foo"}}}, + {desc: "qualified generic type with two params", input: "io.Reader[Foo, Bar]", want: Type{Name: "io.Reader", Params: []string{"Foo", "Bar"}}}, + {desc: "qualified generic param", input: "Reader[io.Reader]", want: Type{Name: "Reader", Params: []string{"io.Reader"}}}, + {desc: "qualified and unqualified generic param", input: "Reader[io.Reader, string]", want: Type{Name: "Reader", Params: []string{"io.Reader", "string"}}}, + {desc: "pointer qualified generic param", input: "Reader[*io.Reader]", want: Type{Name: "Reader", Params: []string{"*io.Reader"}}}, + {desc: "map generic param", input: "Reader[map[string]string]", want: Type{Name: "Reader", Params: []string{"map[string]string"}}}, + {desc: "pointer map generic param", input: "Reader[*map[string]string]", want: Type{Name: "Reader", Params: []string{"*map[string]string"}}}, + {desc: "pointer key map generic param", input: "Reader[map[*string]string]", want: Type{Name: "Reader", Params: []string{"map[*string]string"}}}, + {desc: "pointer value map generic param", input: "Reader[map[string]*string]", want: Type{Name: "Reader", Params: []string{"map[string]*string"}}}, + {desc: "slice generic param", input: "Reader[[]string]", want: Type{Name: "Reader", Params: []string{"[]string"}}}, + {desc: "pointer slice generic param", input: "Reader[*[]string]", want: Type{Name: "Reader", Params: []string{"*[]string"}}}, + {desc: "pointer slice value generic param", input: "Reader[[]*string]", want: Type{Name: "Reader", Params: []string{"[]*string"}}}, + {desc: "array generic param", input: "Reader[[1]string]", want: Type{Name: "Reader", Params: []string{"[1]string"}}}, + {desc: "pointer array generic param", input: "Reader[*[1]string]", want: Type{Name: "Reader", Params: []string{"*[1]string"}}}, + {desc: "pointer array value generic param", input: "Reader[[1]*string]", want: Type{Name: "Reader", Params: []string{"[1]*string"}}}, + {desc: "chan generic param", input: "Reader[chan error]", want: Type{Name: "Reader", Params: []string{"chan error"}}}, + {desc: "receiver chan generic param", input: "Reader[<-chan error]", want: Type{Name: "Reader", Params: []string{"<-chan error"}}}, + {desc: "send chan generic param", input: "Reader[chan<- error]", want: Type{Name: "Reader", Params: []string{"chan<- error"}}}, + {desc: "pointer chan generic param", input: "Reader[*chan error]", want: Type{Name: "Reader", Params: []string{"*chan error"}}}, + {desc: "func generic param", input: "Reader[func() string]", want: Type{Name: "Reader", Params: []string{"func() string"}}}, + {desc: "one arg func generic param", input: "Reader[func(a int) string]", want: Type{Name: "Reader", Params: []string{"func(a int) string"}}}, + {desc: "two arg one type func generic param", input: "Reader[func(a, b int) string]", want: Type{Name: "Reader", Params: []string{"func(a, b int) string"}}}, + {desc: "three arg one type func generic param", input: "Reader[func(a, b, c int) string]", want: Type{Name: "Reader", Params: []string{"func(a, b, c int) string"}}}, + {desc: "three arg two types func generic param", input: "Reader[func(a, b string, c int) string]", want: Type{Name: "Reader", Params: []string{"func(a, b string, c int) string"}}}, + {desc: "three arg three types func generic param", input: "Reader[func(a bool, b string, c int) string]", want: Type{Name: "Reader", Params: []string{"func(a bool, b string, c int) string"}}}, // don't need support for generics on the function type itself; function types must have no type parameters // https://cs.opensource.google/go/go/+/master:src/go/parser/parser.go;l=1048;drc=cafb49ac731f862f386862d64b27b8314eeb2909 } @@ -815,8 +817,8 @@ func TestParseTypeParams(t *testing.T) { } t.Fatalf("unexpected error: %s", err) } - if typ.ID != tt.want.ID { - t.Errorf("wanted ID %q, got %q", tt.want.ID, typ.ID) + if typ.Name != tt.want.Name { + t.Errorf("wanted ID %q, got %q", tt.want.Name, typ.Name) } if len(typ.Params) != len(tt.want.Params) { t.Errorf("wanted %d params, got %d: %v", len(tt.want.Params), len(typ.Params), typ.Params) From 9da095f757c7af661f242c1e9523f5e014f6dec6 Mon Sep 17 00:00:00 2001 From: Josh Bleecher Snyder Date: Tue, 16 Jan 2024 10:14:50 -0800 Subject: [PATCH 14/20] simplify code --- impl.go | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/impl.go b/impl.go index 8ce5d18..755e923 100644 --- a/impl.go +++ b/impl.go @@ -47,23 +47,18 @@ type Type struct { } // String constructs a reference to the Type. For example: -// Type{ID: "Foo", Params{"Bar", "Baz[[]Quux]"}} +// Type{Name: "Foo", Params{"Bar", "Baz[[]Quux]"}} // would yield // Foo[Bar, Baz[[]Quux]] func (t Type) String() string { - var res strings.Builder - res.WriteString(t.Name) if len(t.Params) < 1 { - return res.String() + return t.Name } - res.WriteString("[") - res.WriteString(strings.Join(t.Params, ", ")) - res.WriteString("]") - return res.String() + return t.Name + "[" + strings.Join(t.Params, ", ") + "]" } // parseType parses an interface reference into a Type, allowing us to -// distinguish between the interface's ID or name and its type parameters. +// distinguish between the interface's name and its type parameters. func parseType(in string) (Type, error) { expr, err := parser.ParseExpr(in) if err != nil { From cce7cbde9e0e9b49a006dbaf72b0ac0e1361f8d3 Mon Sep 17 00:00:00 2001 From: Josh Bleecher Snyder Date: Tue, 16 Jan 2024 10:15:01 -0800 Subject: [PATCH 15/20] clean docs --- impl.go | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/impl.go b/impl.go index 755e923..605ae2e 100644 --- a/impl.go +++ b/impl.go @@ -69,17 +69,16 @@ func parseType(in string) (Type, error) { // findInterface returns the import path and type of an interface. // For example, given "http.ResponseWriter", findInterface returns -// "net/http", Type{ID: "ResponseWriter"}. +// "net/http", Type{Name: "ResponseWriter"}. // If a fully qualified interface is given, such as "net/http.ResponseWriter", // it simply parses the input. // If an unqualified interface such as "UserDefinedInterface" is given, then // the interface definition is presumed to be in the package within srcDir and -// findInterface returns "", Type{ID: "UserDefinedInterface"}. +// findInterface returns "", Type{Name: "UserDefinedInterface"}. // -// Generic types will have their type params returned in the Params property of +// Generic types will have their type params set in the Params property of // the Type. Input should always reference generic types with their parameters -// filled, i.e. GenericType[string, bool], as opposed to -// GenericType[A any, B comparable]. +// specified: GenericType[string, bool], not GenericType[A any, B comparable]. func findInterface(input string, srcDir string) (path string, iface Type, err error) { if len(strings.Fields(input)) != 1 && !strings.Contains(input, "[") { return "", Type{}, fmt.Errorf("couldn't parse interface: %s", input) From 1e06e7262ca92a0e228fbe0dedc403cc720e5a00 Mon Sep 17 00:00:00 2001 From: Josh Bleecher Snyder Date: Tue, 16 Jan 2024 10:15:10 -0800 Subject: [PATCH 16/20] add a test case --- impl_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/impl_test.go b/impl_test.go index 412e982..8320e98 100644 --- a/impl_test.go +++ b/impl_test.go @@ -29,6 +29,7 @@ func TestFindInterface(t *testing.T) { {input: "http.ResponseWriter", path: "net/http", typ: Type{Name: "ResponseWriter"}}, {input: "net.Tennis", wantErr: true}, {input: "a + b", wantErr: true}, + {input: "t[T,U]", path: "", typ: Type{Name: "t", Params: []string{"T", "U"}}}, {input: "a/b/c/", wantErr: true}, {input: "a/b/c/pkg", wantErr: true}, {input: "a/b/c/pkg.", wantErr: true}, From 0b027674e49d950c948f1486fb124f97b4bc575f Mon Sep 17 00:00:00 2001 From: Josh Bleecher Snyder Date: Tue, 16 Jan 2024 10:21:19 -0800 Subject: [PATCH 17/20] simplify error message --- impl.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/impl.go b/impl.go index 605ae2e..bdf3a2e 100644 --- a/impl.go +++ b/impl.go @@ -292,7 +292,7 @@ func typeFromAST(in ast.Expr) (Type, error) { return Type{}, err } if len(id.Params) > 0 { - return Type{}, fmt.Errorf("got type parameters for a type ID, which is very confusing: %s", id.String()) + return Type{}, fmt.Errorf("got type parameters for a type name: %s", id.String()) } param, err := typeFromAST(specType.Index) if err != nil { @@ -309,7 +309,7 @@ func typeFromAST(in ast.Expr) (Type, error) { return Type{}, err } if len(id.Params) > 0 { - return Type{}, fmt.Errorf("got type parameters for a type ID, which is very confusing: %s", id.String()) + return Type{}, fmt.Errorf("got type parameters for a type ID: %s", id.String()) } res := Type{ Name: id.Name, From cc2523e288fac68c1134ec4b390be0ff1298ad62 Mon Sep 17 00:00:00 2001 From: Josh Bleecher Snyder Date: Tue, 16 Jan 2024 10:24:29 -0800 Subject: [PATCH 18/20] reduce indentation --- impl.go | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/impl.go b/impl.go index bdf3a2e..aef48ec 100644 --- a/impl.go +++ b/impl.go @@ -424,23 +424,24 @@ func typeSpec(path string, typ Type, srcDir string) (Pkg, Spec, error) { // passed type, a nil map and false are returned. No type checking is done, // only that there are sufficient types to match. func matchTypeParams(spec *ast.TypeSpec, params []string) (map[string]string, bool) { + if spec.TypeParams == nil { + return nil, true + } res := make(map[string]string, len(params)) - if spec.TypeParams != nil { - var specParamNames []string - for _, typeParam := range spec.TypeParams.List { - for _, name := range typeParam.Names { - if name == nil { - continue - } - specParamNames = append(specParamNames, name.Name) + var specParamNames []string + for _, typeParam := range spec.TypeParams.List { + for _, name := range typeParam.Names { + if name == nil { + continue } + specParamNames = append(specParamNames, name.Name) } - if len(specParamNames) != len(params) { - return nil, false - } - for pos, specParamName := range specParamNames { - res[specParamName] = params[pos] - } + } + if len(specParamNames) != len(params) { + return nil, false + } + for pos, specParamName := range specParamNames { + res[specParamName] = params[pos] } return res, true } From e0454a725c845c4b2a3915730b5e34f1a32420f4 Mon Sep 17 00:00:00 2001 From: Josh Bleecher Snyder Date: Tue, 16 Jan 2024 10:29:38 -0800 Subject: [PATCH 19/20] fix CI, I hope --- .github/workflows/test.yml | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index abeb1cc..4d3cef4 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -4,13 +4,14 @@ on: [push, pull_request] jobs: test: - runs-on: 'ubuntu-latest' + runs-on: "ubuntu-latest" steps: - - uses: actions/checkout@master - - uses: actions/setup-go@v2 - with: - go-version: 1.14 - - name: run go tests - run: | - go test -v ./... - go test -v -race ./... + - uses: actions/checkout@master + - uses: actions/setup-go@v4 + with: + go-version-file: "go.mod" + cache: true + - name: run go tests + run: | + go test -v ./... + go test -v -race ./... From 5fa0fa9550f2f53004b32ccd66ef33dbb6d71267 Mon Sep 17 00:00:00 2001 From: Josh Bleecher Snyder Date: Wed, 17 Jan 2024 10:30:31 -0800 Subject: [PATCH 20/20] simplify code by using format.Node --- impl.go | 183 +++++++------------------------------------------------- 1 file changed, 20 insertions(+), 163 deletions(-) diff --git a/impl.go b/impl.go index aef48ec..5182e82 100644 --- a/impl.go +++ b/impl.go @@ -170,151 +170,29 @@ func findInterface(input string, srcDir string) (path string, iface Type, err er } func typeFromAST(in ast.Expr) (Type, error) { - switch specType := in.(type) { - case *ast.Ident: - // a standalone identifier (Reader) shows up as an Ident - return Type{Name: specType.Name}, nil - case *ast.SelectorExpr: - // an identifier in a different package (io.Reader) shows up as a SelectorExpr - // we need to pull the name out - return Type{Name: specType.X.(*ast.Ident).Name + "." + specType.Sel.Name}, nil - case *ast.StarExpr: - // pointer identifiers (*Reader) show up as a StarExpr - // we need to pull the name out and prefix it with a * - typ, err := typeFromAST(specType.X) - if err != nil { - return Type{}, err - } - typ.Name = "*" + typ.Name - return typ, nil - case *ast.ArrayType: - // slices and arrays ([]Reader) show up as an ArrayType - typ, err := typeFromAST(specType.Elt) - if err != nil { - return Type{}, err - } - prefix := "[" - if specType.Len != nil { - prefix += specType.Len.(*ast.BasicLit).Value - } - prefix += "]" - typ.Name = prefix + typ.Name - return typ, nil - case *ast.MapType: - // maps (map[string]Reader) show up as a MapType - key, err := typeFromAST(specType.Key) - if err != nil { - return Type{}, err - } - value, err := typeFromAST(specType.Value) - if err != nil { - return Type{}, err - } - return Type{ - Name: "map[" + key.String() + "]" + value.String(), - }, nil - case *ast.FuncType: - // funcs (func() Reader) show up as a FuncType - // NOTE: we don't actually parse out the type params of a FuncType. - // This should be okay, because we really only care about - // parsing out the type params when parsing interface - // identifiers. And FuncTypes never signify an interface - // identifier, they're just an argument to it or a type param - // of it. - // We don't parse them out anyways like we do for everything - // else because funcs, alone, are pretty weird in how they use - // generics. - // For everything else, it's identifier[params]. - // For funcs, the params get stuck in the middle of the identifier: - // func Foo[Param1, Param2](context.Context, Param1) Param2 - // We're gonna need to complicate everything to support that - // construction, and we don't actually need the deconstructed - // bits, so we're just... not going to deconstruct it at all. - var res strings.Builder - res.WriteString("func") - if specType.TypeParams != nil && len(specType.TypeParams.List) > 0 { - res.WriteString("[") - paramList, err := buildFuncParamList(specType.TypeParams.List) - if err != nil { - return Type{}, err - } - res.WriteString(paramList) - res.WriteString("]") - } - res.WriteString("(") - if specType.Params != nil { - paramList, err := buildFuncParamList(specType.Params.List) - if err != nil { - return Type{}, err - } - res.WriteString(paramList) - } - res.WriteString(")") - if specType.Results != nil && len(specType.Results.List) > 0 { - res.WriteString(" ") - if len(specType.Results.List) > 1 { - res.WriteString("(") - } - paramList, err := buildFuncParamList(specType.Results.List) - if err != nil { - return Type{}, err - } - res.WriteString(paramList) - if len(specType.Results.List) > 1 { - res.WriteString(")") - } - } - return Type{Name: res.String()}, nil - case *ast.ChanType: - var res strings.Builder - // channels (chan Reader) show up as a ChanType - // we need to be careful to preserve send/receive semantics - if specType.Dir&ast.SEND == 0 { - // this is a receive-only channel, write the arrow before the chan keyword - res.WriteString("<-") - } - res.WriteString("chan") - if specType.Dir&ast.RECV == 0 { - // this is a send-only channel, write the arrow after the chan keyword - res.WriteString("<-") - } - res.WriteString(" ") - valType, err := typeFromAST(specType.Value) - if err != nil { - return Type{}, err - } - valType.Name = res.String() + valType.Name - return valType, nil + // Extract type name and params from generic types. + var typeName ast.Expr + var typeParams []ast.Expr + switch in := in.(type) { case *ast.IndexExpr: // a generic type with one type parameter (Reader[Foo]) shows up as an IndexExpr - id, err := typeFromAST(specType.X) - if err != nil { - return Type{}, err - } - if len(id.Params) > 0 { - return Type{}, fmt.Errorf("got type parameters for a type name: %s", id.String()) - } - param, err := typeFromAST(specType.Index) - if err != nil { - return Type{}, err - } - return Type{ - Name: id.Name, - Params: []string{param.String()}, - }, nil + typeName = in.X + typeParams = []ast.Expr{in.Index} case *ast.IndexListExpr: // a generic type with multiple type parameters shows up as an IndexListExpr - id, err := typeFromAST(specType.X) + typeName = in.X + typeParams = in.Indices + } + if typeParams != nil { + id, err := typeFromAST(typeName) if err != nil { return Type{}, err } if len(id.Params) > 0 { - return Type{}, fmt.Errorf("got type parameters for a type ID: %s", id.String()) + return Type{}, fmt.Errorf("unexpected type parameters: %v", in) } - res := Type{ - Name: id.Name, - } - for _, typeParam := range specType.Indices { + res := Type{Name: id.Name} + for _, typeParam := range typeParams { param, err := typeFromAST(typeParam) if err != nil { return Type{}, err @@ -323,34 +201,13 @@ func typeFromAST(in ast.Expr) (Type, error) { } return res, nil } - return Type{}, fmt.Errorf("unexpected AST type %T", in) -} - -// buildFuncParamList returns a string representation of a list of function -// params (type params, function arguments, returns) given an []*ast.Field -// for those things. -func buildFuncParamList(list []*ast.Field) (string, error) { - var res strings.Builder - for pos, field := range list { - for namePos, name := range field.Names { - res.WriteString(name.Name) - if namePos+1 < len(field.Names) { - res.WriteString(", ") - } - } - if len(field.Names) > 0 { - res.WriteString(" ") - } - fieldType, err := typeFromAST(field.Type) - if err != nil { - return "", err - } - res.WriteString(fieldType.String()) - if pos+1 < len(list) { - res.WriteString(", ") - } + // Non-generic type. + buf := new(strings.Builder) + err := format.Node(buf, token.NewFileSet(), in) + if err != nil { + return Type{}, err } - return res.String(), nil + return Type{Name: buf.String()}, nil } // Pkg is a parsed build.Package.