From e6bec826194aeeb10d241a8a4630a7a9a740d932 Mon Sep 17 00:00:00 2001 From: Paddy Date: Wed, 17 Jan 2024 10:36:18 -0800 Subject: [PATCH] support generics Add support for generic interfaces, i.e. interfaces that accept a type parameter. For example, you can now have this kind of interface: ```go type Interface[Kind any] interface { DoTheThing() Kind } ``` and if you run `impl 's StringImpl' 'Interface[string]` it will generate the following code: ```go func (s StringImpl) DoTheThing() string { // normal impl stub here } ``` Add a Type abstraction, make type parsing more robust. Instead of parsing generic types using string manipulation, parse them using the go/parse package to get the AST. Rather than passing around a type ID and its params everywhere, create a Type struct that contains both. Fixes josharian/impl#44. Co-authored-by: Josh Bleecher Snyder --- .github/workflows/test.yml | 19 +-- go.mod | 8 +- go.sum | 48 +----- impl.go | 235 +++++++++++++++++++++++------ impl_test.go | 298 +++++++++++++++++++++++++++---------- testdata/interfaces.go | 96 ++++++++++++ 6 files changed, 524 insertions(+), 180 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index abeb1cc..4d3cef4 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -4,13 +4,14 @@ on: [push, pull_request] jobs: test: - runs-on: 'ubuntu-latest' + runs-on: "ubuntu-latest" steps: - - uses: actions/checkout@master - - uses: actions/setup-go@v2 - with: - go-version: 1.14 - - name: run go tests - run: | - go test -v ./... - go test -v -race ./... + - uses: actions/checkout@master + - uses: actions/setup-go@v4 + with: + go-version-file: "go.mod" + cache: true + - name: run go tests + run: | + go test -v ./... + go test -v -race ./... diff --git a/go.mod b/go.mod index 4ed327c..faddeb2 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,10 @@ module github.com/josharian/impl -go 1.14 +go 1.18 + +require golang.org/x/tools v0.17.0 require ( - golang.org/x/tools v0.4.0 - golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect + golang.org/x/mod v0.14.0 // indirect + golang.org/x/sys v0.16.0 // indirect ) diff --git a/go.sum b/go.sum index 626cea3..bfbc802 100644 --- a/go.sum +++ b/go.sum @@ -1,48 +1,12 @@ -github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/mod v0.2.0 h1:KU7oHjnv3XNWfa5COkzUifxZmxp1TyI7ImMXqFxLwvQ= -golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.7.0 h1:LapD9S96VoQRhi/GrNTqeBJFrUjs5UHCAtTlgwA5oZA= golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.3.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/mod v0.14.0 h1:dGoOF9QVLYng8IHTm7BAyWqCqSheQ5pYWGhzW00YJr0= +golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ= golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.5.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20200522201501-cb1345f3a375 h1:SjQ2+AKWgZLc1xej6WSzL+Dfs5Uyd5xcZH1mGC411IA= -golang.org/x/tools v0.0.0-20200522201501-cb1345f3a375/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= -golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= +golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/tools v0.4.0 h1:7mTAgkunk3fr4GAloyyCasadO6h9zSsQZbwvcaIciV4= golang.org/x/tools v0.4.0/go.mod h1:UE5sM2OK9E/d67R0ANs2xJizIymRP5gJU295PvKXxjQ= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 h1:H2TDz8ibqkAF6YGhCdN3jS9O0/s90v0rJh3X/OLHEUk= -golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8= +golang.org/x/tools v0.17.0 h1:FvmRgNOcs3kOa+T20R1uhfP9F6HgG2mfxDv1vrx1Htc= +golang.org/x/tools v0.17.0/go.mod h1:xsh6VxdV005rRVaS6SSAf9oiAqljS7UZUacMZ8Bnsps= diff --git a/impl.go b/impl.go index 17b16dc..5182e82 100644 --- a/impl.go +++ b/impl.go @@ -26,45 +26,96 @@ var ( flagRecvPkg = flag.String("recvpkg", "", "package name of the receiver") ) -// findInterface returns the import path and identifier of an interface. +// Type is a parsed type reference. +type Type struct { + // Name is the type's name. For example, in "foo[Bar, Baz]", the name + // is "foo". + Name string + + // Params are the type's type params. For example, in "foo[Bar, Baz]", + // the Params are []string{"Bar", "Baz"}. + // + // Params never list the type of the "name type" construction of type + // params used when defining a generic type. They will always be just + // the filling type, as seen when using a generic type. + // + // Params will always be the type parameters only for the top-level + // type; if the params themselves have type parameters, they will + // remain joined to the type name. So "foo[Bar, Baz[Quux]]" will be + // returned as {ID: "foo", Params: []string{"Bar", "Baz[Quux]"}} + Params []string +} + +// String constructs a reference to the Type. For example: +// Type{Name: "Foo", Params{"Bar", "Baz[[]Quux]"}} +// would yield +// Foo[Bar, Baz[[]Quux]] +func (t Type) String() string { + if len(t.Params) < 1 { + return t.Name + } + return t.Name + "[" + strings.Join(t.Params, ", ") + "]" +} + +// parseType parses an interface reference into a Type, allowing us to +// distinguish between the interface's name and its type parameters. +func parseType(in string) (Type, error) { + expr, err := parser.ParseExpr(in) + if err != nil { + return Type{}, err + } + return typeFromAST(expr) +} + +// findInterface returns the import path and type of an interface. // For example, given "http.ResponseWriter", findInterface returns -// "net/http", "ResponseWriter". +// "net/http", Type{Name: "ResponseWriter"}. // If a fully qualified interface is given, such as "net/http.ResponseWriter", // it simply parses the input. // If an unqualified interface such as "UserDefinedInterface" is given, then // the interface definition is presumed to be in the package within srcDir and -// findInterface returns "", "UserDefinedInterface". -func findInterface(iface string, srcDir string) (path string, id string, err error) { - if len(strings.Fields(iface)) != 1 { - return "", "", fmt.Errorf("couldn't parse interface: %s", iface) +// findInterface returns "", Type{Name: "UserDefinedInterface"}. +// +// Generic types will have their type params set in the Params property of +// the Type. Input should always reference generic types with their parameters +// specified: GenericType[string, bool], not GenericType[A any, B comparable]. +func findInterface(input string, srcDir string) (path string, iface Type, err error) { + if len(strings.Fields(input)) != 1 && !strings.Contains(input, "[") { + return "", Type{}, fmt.Errorf("couldn't parse interface: %s", input) } srcPath := filepath.Join(srcDir, "__go_impl__.go") - if slash := strings.LastIndex(iface, "/"); slash > -1 { + if slash := strings.LastIndex(input, "/"); slash > -1 { // package path provided - dot := strings.LastIndex(iface, ".") + dot := strings.LastIndex(input, ".") // make sure iface does not end with "/" (e.g. reject net/http/) - if slash+1 == len(iface) { - return "", "", fmt.Errorf("interface name cannot end with a '/' character: %s", iface) + if slash+1 == len(input) { + return "", Type{}, fmt.Errorf("interface name cannot end with a '/' character: %s", input) } // make sure iface does not end with "." (e.g. reject net/http.) - if dot+1 == len(iface) { - return "", "", fmt.Errorf("interface name cannot end with a '.' character: %s", iface) + if dot+1 == len(input) { + return "", Type{}, fmt.Errorf("interface name cannot end with a '.' character: %s", input) } // make sure iface has at least one "." after "/" (e.g. reject net/http/httputil) - if strings.Count(iface[slash:], ".") == 0 { - return "", "", fmt.Errorf("invalid interface name: %s", iface) + if strings.Count(input[slash:], ".") == 0 { + return "", Type{}, fmt.Errorf("invalid interface name: %s", input) } - return iface[:dot], iface[dot+1:], nil + path = input[:dot] + id := input[dot+1:] + iface, err = parseType(id) + if err != nil { + return "", Type{}, err + } + return path, iface, nil } - src := []byte("package hack\n" + "var i " + iface) + src := []byte("package hack\n" + "var i " + input) // If we couldn't determine the import path, goimports will // auto fix the import path. imp, err := imports.Process(srcPath, src, nil) if err != nil { - return "", "", fmt.Errorf("couldn't parse interface: %s", iface) + return "", Type{}, fmt.Errorf("couldn't parse interface: %s", input) } // imp should now contain an appropriate import. @@ -75,10 +126,10 @@ func findInterface(iface string, srcDir string) (path string, id string, err err panic(err) } - qualified := strings.Contains(iface, ".") + qualified := strings.Contains(input, ".") if len(f.Imports) == 0 && qualified { - return "", "", fmt.Errorf("unrecognized interface: %s", iface) + return "", Type{}, fmt.Errorf("unrecognized interface: %s", input) } if !qualified { @@ -87,12 +138,10 @@ func findInterface(iface string, srcDir string) (path string, id string, err err // package hack // // var i Reader - decl := f.Decls[0].(*ast.GenDecl) // var i io.Reader - spec := decl.Specs[0].(*ast.ValueSpec) // i io.Reader - sel := spec.Type.(*ast.Ident) - id = sel.Name // Reader - - return path, id, nil + decl := f.Decls[0].(*ast.GenDecl) // var i Reader + spec := decl.Specs[0].(*ast.ValueSpec) // i Reader + iface, err = typeFromAST(spec.Type) + return path, iface, err } // If qualified, the code looks like: @@ -111,10 +160,54 @@ func findInterface(iface string, srcDir string) (path string, id string, err err } decl := f.Decls[1].(*ast.GenDecl) // var i io.Reader spec := decl.Specs[0].(*ast.ValueSpec) // i io.Reader - sel := spec.Type.(*ast.SelectorExpr) // io.Reader - id = sel.Sel.Name // Reader + iface, err = typeFromAST(spec.Type) + if err != nil { + return path, iface, fmt.Errorf("error parsing type from AST: %w", err) + } + // trim off the package which got smooshed on when resolving the type + _, iface.Name, _ = strings.Cut(iface.Name, ".") + return path, iface, err +} - return path, id, nil +func typeFromAST(in ast.Expr) (Type, error) { + // Extract type name and params from generic types. + var typeName ast.Expr + var typeParams []ast.Expr + switch in := in.(type) { + case *ast.IndexExpr: + // a generic type with one type parameter (Reader[Foo]) shows up as an IndexExpr + typeName = in.X + typeParams = []ast.Expr{in.Index} + case *ast.IndexListExpr: + // a generic type with multiple type parameters shows up as an IndexListExpr + typeName = in.X + typeParams = in.Indices + } + if typeParams != nil { + id, err := typeFromAST(typeName) + if err != nil { + return Type{}, err + } + if len(id.Params) > 0 { + return Type{}, fmt.Errorf("unexpected type parameters: %v", in) + } + res := Type{Name: id.Name} + for _, typeParam := range typeParams { + param, err := typeFromAST(typeParam) + if err != nil { + return Type{}, err + } + res.Params = append(res.Params, param.String()) + } + return res, nil + } + // Non-generic type. + buf := new(strings.Builder) + err := format.Node(buf, token.NewFileSet(), in) + if err != nil { + return Type{}, err + } + return Type{Name: buf.String()}, nil } // Pkg is a parsed build.Package. @@ -125,20 +218,27 @@ type Pkg struct { recvPkg string } +// Spec is ast.TypeSpec with the associated comment map. +type Spec struct { + *ast.TypeSpec + ast.CommentMap + TypeParams map[string]string +} + // typeSpec locates the *ast.TypeSpec for type id in the import path. -func typeSpec(path string, id string, srcDir string) (Pkg, *ast.TypeSpec, error) { +func typeSpec(path string, typ Type, srcDir string) (Pkg, Spec, error) { var pkg *build.Package var err error if path == "" { pkg, err = build.ImportDir(srcDir, 0) if err != nil { - return Pkg{}, nil, fmt.Errorf("couldn't find package in %s: %v", srcDir, err) + return Pkg{}, Spec{}, fmt.Errorf("couldn't find package in %s: %v", srcDir, err) } } else { pkg, err = build.Import(path, srcDir, 0) if err != nil { - return Pkg{}, nil, fmt.Errorf("couldn't find package %s: %v", path, err) + return Pkg{}, Spec{}, fmt.Errorf("couldn't find package %s: %v", path, err) } } @@ -159,15 +259,48 @@ func typeSpec(path string, id string, srcDir string) (Pkg, *ast.TypeSpec, error) } for _, spec := range decl.Specs { spec := spec.(*ast.TypeSpec) - if spec.Name.Name != id { + if spec.Name.Name != typ.Name { + continue + } + typeParams, ok := matchTypeParams(spec, typ.Params) + if !ok { continue } p := Pkg{Package: pkg, FileSet: fset} - return p, spec, nil + s := Spec{TypeSpec: spec, TypeParams: typeParams} + return p, s, nil } } } - return Pkg{}, nil, fmt.Errorf("type %s not found in %s", id, path) + return Pkg{}, Spec{}, fmt.Errorf("type %s not found in %s", typ.Name, path) +} + +// matchTypeParams returns a map of type parameters from a parsed interface +// definition and the types that fill them from the user's specified type +// info. If the passed params can't be used to fill the type parameters on the +// passed type, a nil map and false are returned. No type checking is done, +// only that there are sufficient types to match. +func matchTypeParams(spec *ast.TypeSpec, params []string) (map[string]string, bool) { + if spec.TypeParams == nil { + return nil, true + } + res := make(map[string]string, len(params)) + var specParamNames []string + for _, typeParam := range spec.TypeParams.List { + for _, name := range typeParam.Names { + if name == nil { + continue + } + specParamNames = append(specParamNames, name.Name) + } + } + if len(specParamNames) != len(params) { + return nil, false + } + for pos, specParamName := range specParamNames { + res[specParamName] = params[pos] + } + return res, true } // gofmt pretty-prints e. @@ -203,15 +336,25 @@ func (p Pkg) fullType(e ast.Expr) string { return p.gofmt(e) } -func (p Pkg) params(field *ast.Field) []Param { +func (p Pkg) params(field *ast.Field, typeParams map[string]string) []Param { var params []Param - typ := p.fullType(field.Type) + var typ string + switch expr := field.Type.(type) { + case *ast.Ident: + if genType, ok := typeParams[expr.Name]; ok { + typ = genType + } else { + typ = p.fullType(field.Type) + } + default: + typ = p.fullType(field.Type) + } for _, name := range field.Names { params = append(params, Param{Name: name.Name, Type: typ}) } // Handle anonymous params if len(params) == 0 { - params = []Param{Param{Type: typ}} + params = []Param{{Type: typ}} } return params } @@ -244,12 +387,12 @@ const ( WithoutComments EmitComments = false ) -func (p Pkg) funcsig(f *ast.Field, comments EmitComments) Func { +func (p Pkg) funcsig(f *ast.Field, typeParams map[string]string, cmap ast.CommentMap, comments EmitComments) Func { fn := Func{Name: f.Names[0].Name} typ := f.Type.(*ast.FuncType) if typ.Params != nil { for _, field := range typ.Params.List { - for _, param := range p.params(field) { + for _, param := range p.params(field, typeParams) { // only for method parameters: // assign a blank identifier "_" to an anonymous parameter if param.Name == "" { @@ -261,7 +404,7 @@ func (p Pkg) funcsig(f *ast.Field, comments EmitComments) Func { } if typ.Results != nil { for _, field := range typ.Results.List { - fn.Res = append(fn.Res, p.params(field)...) + fn.Res = append(fn.Res, p.params(field, typeParams)...) } } if comments == WithComments && f.Doc != nil { @@ -286,13 +429,13 @@ func funcs(iface, srcDir, recvPkg string, comments EmitComments) ([]Func, error) } // Locate the interface. - path, id, err := findInterface(iface, srcDir) + path, typ, err := findInterface(iface, srcDir) if err != nil { return nil, err } // Parse the package and find the interface declaration. - p, spec, err := typeSpec(path, id, srcDir) + p, spec, err := typeSpec(path, typ, srcDir) if err != nil { return nil, fmt.Errorf("interface %s not found: %s", iface, err) } @@ -319,7 +462,7 @@ func funcs(iface, srcDir, recvPkg string, comments EmitComments) ([]Func, error) continue } - fn := p.funcsig(fndecl, comments) + fn := p.funcsig(fndecl, spec.TypeParams, spec.CommentMap.Filter(fndecl), comments) fns = append(fns, fn) } return fns, nil @@ -417,7 +560,7 @@ impl [-dir directory] fmt.Fprint(os.Stderr, ` Examples: - + impl 'f *File' io.Reader impl Murmur hash.Hash impl -dir $GOPATH/src/github.com/josharian/impl Murmur hash.Hash @@ -444,13 +587,13 @@ to prevent shell globbing. } } - var recvPkg = *flagRecvPkg + recvPkg := *flagRecvPkg if recvPkg == "" { // " s *Struct " , receiver: Struct recvs := strings.Fields(recv) receiver := recvs[len(recvs)-1] // note that this correctly handles "s *Struct" and "*Struct" receiver = strings.TrimPrefix(receiver, "*") - pkg, _, err := typeSpec("", receiver, *flagSrcDir) + pkg, _, err := typeSpec("", Type{Name: receiver}, *flagSrcDir) if err == nil { recvPkg = pkg.Package.Name } diff --git a/impl_test.go b/impl_test.go index ffed4e4..8320e98 100644 --- a/impl_test.go +++ b/impl_test.go @@ -20,36 +20,47 @@ func (b errBool) String() string { func TestFindInterface(t *testing.T) { t.Parallel() cases := []struct { - iface string + input string path string - id string + typ Type wantErr bool }{ - {iface: "net.Conn", path: "net", id: "Conn"}, - {iface: "http.ResponseWriter", path: "net/http", id: "ResponseWriter"}, - {iface: "net.Tennis", wantErr: true}, - {iface: "a + b", wantErr: true}, - {iface: "a/b/c/", wantErr: true}, - {iface: "a/b/c/pkg", wantErr: true}, - {iface: "a/b/c/pkg.", wantErr: true}, - {iface: "a/b/c/pkg.Typ", path: "a/b/c/pkg", id: "Typ"}, - {iface: "gopkg.in/yaml.v2.Unmarshaler", path: "gopkg.in/yaml.v2", id: "Unmarshaler"}, + {input: "net.Conn", path: "net", typ: Type{Name: "Conn"}}, + {input: "http.ResponseWriter", path: "net/http", typ: Type{Name: "ResponseWriter"}}, + {input: "net.Tennis", wantErr: true}, + {input: "a + b", wantErr: true}, + {input: "t[T,U]", path: "", typ: Type{Name: "t", Params: []string{"T", "U"}}}, + {input: "a/b/c/", wantErr: true}, + {input: "a/b/c/pkg", wantErr: true}, + {input: "a/b/c/pkg.", wantErr: true}, + {input: "a/b/c/pkg.Typ", path: "a/b/c/pkg", typ: Type{Name: "Typ"}}, + {input: "gopkg.in/yaml.v2.Unmarshaler", path: "gopkg.in/yaml.v2", typ: Type{Name: "Unmarshaler"}}, + {input: "github.com/josharian/impl/testdata.GenericInterface1[string]", path: "github.com/josharian/impl/testdata", typ: Type{Name: "GenericInterface1", Params: []string{"string"}}}, + {input: "github.com/josharian/impl/testdata.GenericInterface1[*string]", path: "github.com/josharian/impl/testdata", typ: Type{Name: "GenericInterface1", Params: []string{"*string"}}}, } for _, tt := range cases { tt := tt - t.Run(tt.iface, func(t *testing.T) { + t.Run(tt.input, func(t *testing.T) { t.Parallel() - path, id, err := findInterface(tt.iface, ".") + path, typ, err := findInterface(tt.input, ".") gotErr := err != nil if tt.wantErr != gotErr { - t.Fatalf("findInterface(%q).err=%v want %s", tt.iface, err, errBool(tt.wantErr)) + t.Fatalf("findInterface(%q).err=%v want %s", tt.input, err, errBool(tt.wantErr)) } if tt.path != path { - t.Errorf("findInterface(%q).path=%q want %q", tt.iface, path, tt.path) + t.Errorf("findInterface(%q).path=%q want %q", tt.input, path, tt.path) } - if tt.id != id { - t.Errorf("findInterface(%q).id=%q want %q", tt.iface, id, tt.id) + if tt.typ.Name != typ.Name { + t.Errorf("findInterface(%q).id=%q want %q", tt.input, typ.Name, tt.typ.Name) + } + if len(tt.typ.Params) != len(typ.Params) { + t.Errorf("findInterface(%q).len(typeParams)=%d want %d", tt.input, len(typ.Params), len(tt.typ.Params)) + } + for pos, v := range tt.typ.Params { + if v != typ.Params[pos] { + t.Errorf("findInterface(%q).typeParams[%d]=%q, want %q", tt.input, pos, typ.Params[pos], v) + } } }) } @@ -59,26 +70,26 @@ func TestTypeSpec(t *testing.T) { // For now, just test whether we can find the interface. cases := []struct { path string - id string + typ Type wantErr bool }{ - {path: "net", id: "Conn"}, - {path: "net", id: "Con", wantErr: true}, + {path: "net", typ: Type{Name: "Conn"}}, + {path: "net", typ: Type{Name: "Con"}, wantErr: true}, } for _, tt := range cases { - p, spec, err := typeSpec(tt.path, tt.id, "") + p, spec, err := typeSpec(tt.path, tt.typ, "") gotErr := err != nil if tt.wantErr != gotErr { - t.Errorf("typeSpec(%q, %q).err=%v want %s", tt.path, tt.id, err, errBool(tt.wantErr)) + t.Errorf("typeSpec(%q, %q).err=%v want %s", tt.path, tt.typ, err, errBool(tt.wantErr)) continue } if err == nil { if reflect.DeepEqual(p, Pkg{}) { - t.Errorf("typeSpec(%q, %q).pkg=Pkg{} want non-nil", tt.path, tt.id) + t.Errorf("typeSpec(%q, %q).pkg=Pkg{} want non-nil", tt.path, tt.typ) } - if spec == nil { - t.Errorf("typeSpec(%q, %q).spec=nil want non-nil", tt.path, tt.id) + if reflect.DeepEqual(spec, Spec{}) { + t.Errorf("typeSpec(%q, %q).spec=Spec{} want non-nil", tt.path, tt.typ) } } } @@ -252,6 +263,25 @@ func TestFuncs(t *testing.T) { comments: WithComments, }, {iface: "net.Tennis", wantErr: true}, + { + iface: "github.com/josharian/impl/testdata.GenericInterface1[int]", + want: []Func{ + { + Name: "Method1", + Res: []Param{{Type: "int"}}, + }, + { + Name: "Method2", + Params: []Param{{Name: "_", Type: "int"}}, + }, + { + Name: "Method3", + Params: []Param{{Name: "_", Type: "int"}}, + Res: []Param{{Type: "int"}}, + }, + }, + comments: WithComments, + }, } for _, tt := range cases { @@ -311,69 +341,70 @@ func TestValidMethodComments(t *testing.T) { { iface: "github.com/josharian/impl/testdata.Interface1", want: []Func{ - Func{ + { Name: "Method1", Params: []Param{ - Param{ + { Name: "arg1", Type: "string", - }, Param{ + }, { Name: "arg2", Type: "string", - }}, + }, + }, Res: []Param{ - Param{ + { Name: "result", Type: "string", }, - Param{ + { Name: "err", Type: "error", }, }, Comments: "// Method1 is the first method of Interface1.\n", }, - Func{ + { Name: "Method2", Params: []Param{ - Param{ + { Name: "arg1", Type: "int", }, - Param{ + { Name: "arg2", Type: "int", }, }, Res: []Param{ - Param{ + { Name: "result", Type: "int", }, - Param{ + { Name: "err", Type: "error", }, }, Comments: "// Method2 is the second method of Interface1.\n", }, - Func{ + { Name: "Method3", Params: []Param{ - Param{ + { Name: "arg1", Type: "bool", }, - Param{ + { Name: "arg2", Type: "bool", }, }, Res: []Param{ - Param{ + { Name: "result", Type: "bool", }, - Param{ + { Name: "err", Type: "error", }, @@ -385,72 +416,72 @@ func TestValidMethodComments(t *testing.T) { { iface: "github.com/josharian/impl/testdata.Interface2", want: []Func{ - Func{ + { Name: "Method1", Params: []Param{ - Param{ + { Name: "arg1", Type: "int64", }, - Param{ + { Name: "arg2", Type: "int64", }, }, Res: []Param{ - Param{ + { Name: "result", Type: "int64", }, - Param{ + { Name: "err", Type: "error", }, }, Comments: "/*\n\t\tMethod1 is the first method of Interface2.\n\t*/\n", }, - Func{ + { Name: "Method2", Params: []Param{ - Param{ + { Name: "arg1", Type: "float64", }, - Param{ + { Name: "arg2", Type: "float64", }, }, Res: []Param{ - Param{ + { Name: "result", Type: "float64", }, - Param{ + { Name: "err", Type: "error", }, }, Comments: "/*\n\t\tMethod2 is the second method of Interface2.\n\t*/\n", }, - Func{ + { Name: "Method3", Params: []Param{ - Param{ + { Name: "arg1", Type: "interface{}", }, - Param{ + { Name: "arg2", Type: "interface{}", }, }, Res: []Param{ - Param{ + { Name: "result", Type: "interface{}", }, - Param{ + { Name: "err", Type: "error", }, @@ -462,69 +493,70 @@ func TestValidMethodComments(t *testing.T) { { iface: "github.com/josharian/impl/testdata.Interface3", want: []Func{ - Func{ + { Name: "Method1", Params: []Param{ - Param{ + { Name: "_", Type: "string", - }, Param{ + }, { Name: "_", Type: "string", - }}, + }, + }, Res: []Param{ - Param{ + { Name: "", Type: "string", }, - Param{ + { Name: "", Type: "error", }, }, Comments: "// Method1 is the first method of Interface3.\n", }, - Func{ + { Name: "Method2", Params: []Param{ - Param{ + { Name: "_", Type: "int", }, - Param{ + { Name: "arg2", Type: "int", }, }, Res: []Param{ - Param{ + { Name: "_", Type: "int", }, - Param{ + { Name: "err", Type: "error", }, }, Comments: "// Method2 is the second method of Interface3.\n", }, - Func{ + { Name: "Method3", Params: []Param{ - Param{ + { Name: "arg1", Type: "bool", }, - Param{ + { Name: "arg2", Type: "bool", }, }, Res: []Param{ - Param{ + { Name: "result1", Type: "bool", }, - Param{ + { Name: "result2", Type: "bool", }, @@ -577,16 +609,48 @@ func TestStubGeneration(t *testing.T) { want: testdata.Interface9Output, dir: ".", }, + { + iface: "github.com/josharian/impl/testdata.GenericInterface1[string]", + want: testdata.GenericInterface1Output, + dir: ".", + }, + { + iface: "GenericInterface1[string]", + want: testdata.GenericInterface1Output, + dir: "testdata", + }, + { + iface: "github.com/josharian/impl/testdata.GenericInterface2[string, bool]", + want: testdata.GenericInterface2Output, + dir: ".", + }, + { + iface: "GenericInterface2[string, bool]", + want: testdata.GenericInterface2Output, + dir: "testdata", + }, + { + iface: "github.com/josharian/impl/testdata.GenericInterface3[string, bool]", + want: testdata.GenericInterface3Output, + dir: ".", + }, + { + iface: "GenericInterface3[string, bool]", + want: testdata.GenericInterface3Output, + dir: "testdata", + }, } for _, tt := range cases { - fns, err := funcs(tt.iface, tt.dir, "", WithComments) - if err != nil { - t.Errorf("funcs(%q).err=%v", tt.iface, err) - } - src := genStubs("r *Receiver", fns, nil) - if string(src) != tt.want { - t.Errorf("genStubs(\"r *Receiver\", %+#v).src=\n%#v\nwant\n%#v\n", fns, string(src), tt.want) - } + t.Run(tt.iface, func(t *testing.T) { + fns, err := funcs(tt.iface, tt.dir, "", WithComments) + if err != nil { + t.Errorf("funcs(%q).err=%v", tt.iface, err) + } + src := genStubs("r *Receiver", fns, nil) + if string(src) != tt.want { + t.Errorf("genStubs(\"r *Receiver\", %+#v).src=\n%#v\nwant\n%#v\n", fns, string(src), tt.want) + } + }) } } @@ -694,3 +758,77 @@ func TestStubGenerationForRepeatedName(t *testing.T) { }) } } + +func TestParseTypeParams(t *testing.T) { + t.Parallel() + + cases := []struct { + desc string + input string + want Type + wantErr bool + }{ + {desc: "non-generic type", input: "Reader", want: Type{Name: "Reader"}}, + {desc: "one type param", input: "Reader[Foo]", want: Type{Name: "Reader", Params: []string{"Foo"}}}, + {desc: "two type params", input: "Reader[Foo, Bar]", want: Type{Name: "Reader", Params: []string{"Foo", "Bar"}}}, + {desc: "three type params", input: "Reader[Foo, Bar, Baz]", want: Type{Name: "Reader", Params: []string{"Foo", "Bar", "Baz"}}}, + {desc: "no spaces", input: "Reader[Foo,Bar]", want: Type{Name: "Reader", Params: []string{"Foo", "Bar"}}}, + {desc: "unclosed brackets", input: "Reader[Foo", wantErr: true}, + {desc: "no params", input: "Reader[]", wantErr: true}, + {desc: "space-only params", input: "Reader[ ]", wantErr: true}, + {desc: "multiple space-only params", input: "Reader[ , , ]", wantErr: true}, + {desc: "characters after bracket", input: "Reader[Foo]Bar", wantErr: true}, + {desc: "qualified generic type", input: "io.Reader[Foo]", want: Type{Name: "io.Reader", Params: []string{"Foo"}}}, + {desc: "qualified generic type with two params", input: "io.Reader[Foo, Bar]", want: Type{Name: "io.Reader", Params: []string{"Foo", "Bar"}}}, + {desc: "qualified generic param", input: "Reader[io.Reader]", want: Type{Name: "Reader", Params: []string{"io.Reader"}}}, + {desc: "qualified and unqualified generic param", input: "Reader[io.Reader, string]", want: Type{Name: "Reader", Params: []string{"io.Reader", "string"}}}, + {desc: "pointer qualified generic param", input: "Reader[*io.Reader]", want: Type{Name: "Reader", Params: []string{"*io.Reader"}}}, + {desc: "map generic param", input: "Reader[map[string]string]", want: Type{Name: "Reader", Params: []string{"map[string]string"}}}, + {desc: "pointer map generic param", input: "Reader[*map[string]string]", want: Type{Name: "Reader", Params: []string{"*map[string]string"}}}, + {desc: "pointer key map generic param", input: "Reader[map[*string]string]", want: Type{Name: "Reader", Params: []string{"map[*string]string"}}}, + {desc: "pointer value map generic param", input: "Reader[map[string]*string]", want: Type{Name: "Reader", Params: []string{"map[string]*string"}}}, + {desc: "slice generic param", input: "Reader[[]string]", want: Type{Name: "Reader", Params: []string{"[]string"}}}, + {desc: "pointer slice generic param", input: "Reader[*[]string]", want: Type{Name: "Reader", Params: []string{"*[]string"}}}, + {desc: "pointer slice value generic param", input: "Reader[[]*string]", want: Type{Name: "Reader", Params: []string{"[]*string"}}}, + {desc: "array generic param", input: "Reader[[1]string]", want: Type{Name: "Reader", Params: []string{"[1]string"}}}, + {desc: "pointer array generic param", input: "Reader[*[1]string]", want: Type{Name: "Reader", Params: []string{"*[1]string"}}}, + {desc: "pointer array value generic param", input: "Reader[[1]*string]", want: Type{Name: "Reader", Params: []string{"[1]*string"}}}, + {desc: "chan generic param", input: "Reader[chan error]", want: Type{Name: "Reader", Params: []string{"chan error"}}}, + {desc: "receiver chan generic param", input: "Reader[<-chan error]", want: Type{Name: "Reader", Params: []string{"<-chan error"}}}, + {desc: "send chan generic param", input: "Reader[chan<- error]", want: Type{Name: "Reader", Params: []string{"chan<- error"}}}, + {desc: "pointer chan generic param", input: "Reader[*chan error]", want: Type{Name: "Reader", Params: []string{"*chan error"}}}, + {desc: "func generic param", input: "Reader[func() string]", want: Type{Name: "Reader", Params: []string{"func() string"}}}, + {desc: "one arg func generic param", input: "Reader[func(a int) string]", want: Type{Name: "Reader", Params: []string{"func(a int) string"}}}, + {desc: "two arg one type func generic param", input: "Reader[func(a, b int) string]", want: Type{Name: "Reader", Params: []string{"func(a, b int) string"}}}, + {desc: "three arg one type func generic param", input: "Reader[func(a, b, c int) string]", want: Type{Name: "Reader", Params: []string{"func(a, b, c int) string"}}}, + {desc: "three arg two types func generic param", input: "Reader[func(a, b string, c int) string]", want: Type{Name: "Reader", Params: []string{"func(a, b string, c int) string"}}}, + {desc: "three arg three types func generic param", input: "Reader[func(a bool, b string, c int) string]", want: Type{Name: "Reader", Params: []string{"func(a bool, b string, c int) string"}}}, + // don't need support for generics on the function type itself; function types must have no type parameters + // https://cs.opensource.google/go/go/+/master:src/go/parser/parser.go;l=1048;drc=cafb49ac731f862f386862d64b27b8314eeb2909 + } + for _, tt := range cases { + tt := tt + t.Run(tt.desc, func(t *testing.T) { + t.Parallel() + + typ, err := parseType(tt.input) + if err != nil { + if tt.wantErr { + return + } + t.Fatalf("unexpected error: %s", err) + } + if typ.Name != tt.want.Name { + t.Errorf("wanted ID %q, got %q", tt.want.Name, typ.Name) + } + if len(typ.Params) != len(tt.want.Params) { + t.Errorf("wanted %d params, got %d: %v", len(tt.want.Params), len(typ.Params), typ.Params) + } + for pos, param := range typ.Params { + if param != tt.want.Params[pos] { + t.Errorf("expected param %d to be %q, got %q: %v", pos, tt.want.Params[pos], param, typ.Params) + } + } + }) + } +} diff --git a/testdata/interfaces.go b/testdata/interfaces.go index 3167166..3b9f25b 100644 --- a/testdata/interfaces.go +++ b/testdata/interfaces.go @@ -48,6 +48,42 @@ type Interface3 interface { Method3(arg1, arg2 bool) (result1, result2 bool) } +// GenericInterface1 is a dummy interface to test the program output. This +// interface tests generation of generic interfaces with the specified type +// parameters. +type GenericInterface1[Type any] interface { + // Method1 is the first method of GenericInterface1. + Method1() Type + // Method2 is the second method of GenericInterface1. + Method2(Type) + // Method3 is the third method of GenericInterface1. + Method3(Type) Type +} + +// GenericInterface2 is a dummy interface to test the program output. This +// interface tests generation of generic interfaces with the specified type +// parameters. +type GenericInterface2[Type1 any, Type2 comparable] interface { + // Method1 is the first method of GenericInterface2. + Method1() (Type1, Type2) + // Method2 is the second method of GenericInterface2. + Method2(Type1, Type2) + // Method3 is the third method of GenericInterface2. + Method3(Type1) Type2 +} + +// GenericInterface3 is a dummy interface to test the program output. This +// interface tests generation of generic interfaces with repeated type +// parameters. +type GenericInterface3[Type1, Type2 any] interface { + // Method1 is the first method of GenericInterface3. + Method1() (Type1, Type2) + // Method2 is the second method of GenericInterface3. + Method2(Type1, Type2) + // Method3 is the third method of GenericInterface3. + Method3(Type1) Type2 +} + // Interface1Output is the expected output generated from reflecting on // Interface1, provided that the receiver is equal to 'r *Receiver'. var Interface1Output = `// Method1 is the first method of Interface1. @@ -172,3 +208,63 @@ func (arg3 *Implemented) Method2(arg1 string, arg2 int) (_ error) { } ` + +// GenericInterface1Output is the expected output generated from reflecting on +// GenericInterface1, provided that the receiver is equal to 'r *Receiver' and +// it was generated with the type parameters [string]. +var GenericInterface1Output = `// Method1 is the first method of GenericInterface1. +func (r *Receiver) Method1() string { + panic("not implemented") // TODO: Implement +} + +// Method2 is the second method of GenericInterface1. +func (r *Receiver) Method2(_ string) { + panic("not implemented") // TODO: Implement +} + +// Method3 is the third method of GenericInterface1. +func (r *Receiver) Method3(_ string) string { + panic("not implemented") // TODO: Implement +} + +` + +// GenericInterface2Output is the expected output generated from reflecting on +// GenericInterface2, provided that the receiver is equal to 'r *Receiver' and +// it was generated with the type parameters [string, bool]. +var GenericInterface2Output = `// Method1 is the first method of GenericInterface2. +func (r *Receiver) Method1() (string, bool) { + panic("not implemented") // TODO: Implement +} + +// Method2 is the second method of GenericInterface2. +func (r *Receiver) Method2(_ string, _ bool) { + panic("not implemented") // TODO: Implement +} + +// Method3 is the third method of GenericInterface2. +func (r *Receiver) Method3(_ string) bool { + panic("not implemented") // TODO: Implement +} + +` + +// GenericInterface3Output is the expected output generated from reflecting on +// GenericInterface3, provided that the receiver is equal to 'r *Receiver' and +// it was generated with the type parameters [string, bool]. +var GenericInterface3Output = `// Method1 is the first method of GenericInterface3. +func (r *Receiver) Method1() (string, bool) { + panic("not implemented") // TODO: Implement +} + +// Method2 is the second method of GenericInterface3. +func (r *Receiver) Method2(_ string, _ bool) { + panic("not implemented") // TODO: Implement +} + +// Method3 is the third method of GenericInterface3. +func (r *Receiver) Method3(_ string) bool { + panic("not implemented") // TODO: Implement +} + +`