Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support generic interface generation #175

Merged
merged 1 commit into from
Oct 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
module github.com/matryer/moq

go 1.14
go 1.18

require (
github.com/pmezard/go-difflib v1.0.0
golang.org/x/tools v0.1.10
)

require (
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // indirect
golang.org/x/sys v0.0.0-20211019181941-9d821ace8654 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
)
21 changes: 0 additions & 21 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,31 +1,10 @@
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 h1:kQgndtyPBW/JIYERgdxfwMYh3AVStj88WQTlNDi2a+o=
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211019181941-9d821ace8654 h1:id054HUawV2/6IGm2IV8KZQjqtwAOo2CYlOToYqa0d0=
golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.10 h1:QjFRCZxdOhBJ/UNgnBZLbNV13DlbnK0quyivTnXJM20=
golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
14 changes: 10 additions & 4 deletions internal/registry/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,23 @@ func (r Registry) SrcPkgName() string {

// LookupInterface returns the underlying interface definition of the
// given interface name.
func (r Registry) LookupInterface(name string) (*types.Interface, error) {
func (r Registry) LookupInterface(name string) (*types.Interface, *types.TypeParamList, error) {
obj := r.SrcPkg().Scope().Lookup(name)
if obj == nil {
return nil, fmt.Errorf("interface not found: %s", name)
return nil, nil, fmt.Errorf("interface not found: %s", name)
}

if !types.IsInterface(obj.Type()) {
return nil, fmt.Errorf("%s (%s) is not an interface", name, obj.Type())
return nil, nil, fmt.Errorf("%s (%s) is not an interface", name, obj.Type())
}

return obj.Type().Underlying().(*types.Interface).Complete(), nil
var tparams *types.TypeParamList
named, ok := obj.Type().(*types.Named)
if ok {
tparams = named.TypeParams()
}

return obj.Type().Underlying().(*types.Interface).Complete(), tparams, nil
}

// MethodScope returns a new MethodScope.
Expand Down
40 changes: 36 additions & 4 deletions internal/template/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,22 @@ import (
{{- if not $.SkipEnsure -}}
// Ensure, that {{.MockName}} does implement {{$.SrcPkgQualifier}}{{.InterfaceName}}.
// If this is not the case, regenerate this file with moq.
var _ {{$.SrcPkgQualifier}}{{.InterfaceName}} = &{{.MockName}}{}
var _ {{$.SrcPkgQualifier}}{{.InterfaceName -}}
{{- if .TypeParams }}[
{{- range $index, $param := .TypeParams}}
{{- if $index}}, {{end -}}
{{if $param.Constraint}}{{$param.Constraint.String}}{{else}}{{$param.TypeString}}{{end}}
{{- end -}}
]
{{- end }} = &{{.MockName}}
{{- if .TypeParams }}[
{{- range $index, $param := .TypeParams}}
{{- if $index}}, {{end -}}
{{if $param.Constraint}}{{$param.Constraint.String}}{{else}}{{$param.TypeString}}{{end}}
{{- end -}}
]
{{- end -}}
{}
{{- end}}

// {{.MockName}} is a mock implementation of {{$.SrcPkgQualifier}}{{.InterfaceName}}.
Expand All @@ -68,7 +83,12 @@ var _ {{$.SrcPkgQualifier}}{{.InterfaceName}} = &{{.MockName}}{}
// // and then make assertions.
//
// }
type {{.MockName}} struct {
type {{.MockName}}
{{- if .TypeParams -}}
[{{- range $index, $param := .TypeParams}}
{{- if $index}}, {{end}}{{$param.Name | Exported}} {{$param.TypeString}}
{{- end -}}]
{{- end }} struct {
{{- range .Methods}}
// {{.Name}}Func mocks the {{.Name}} method.
{{.Name}}Func func({{.ArgList}}) {{.ReturnArgTypeList}}
Expand All @@ -91,7 +111,13 @@ type {{.MockName}} struct {
}
{{range .Methods}}
// {{.Name}} calls {{.Name}}Func.
func (mock *{{$mock.MockName}}) {{.Name}}({{.ArgList}}) {{.ReturnArgTypeList}} {
func (mock *{{$mock.MockName}}
{{- if $mock.TypeParams -}}
[{{- range $index, $param := $mock.TypeParams}}
{{- if $index}}, {{end}}{{$param.Name | Exported}}
{{- end -}}]
{{- end -}}
) {{.Name}}({{.ArgList}}) {{.ReturnArgTypeList}} {
{{- if not $.StubImpl}}
if mock.{{.Name}}Func == nil {
panic("{{$mock.MockName}}.{{.Name}}Func: method is nil but {{$mock.InterfaceName}}.{{.Name}} was just called")
Expand Down Expand Up @@ -134,7 +160,13 @@ func (mock *{{$mock.MockName}}) {{.Name}}({{.ArgList}}) {{.ReturnArgTypeList}} {
// {{.Name}}Calls gets all the calls that were made to {{.Name}}.
// Check the length with:
// len(mocked{{$mock.InterfaceName}}.{{.Name}}Calls())
func (mock *{{$mock.MockName}}) {{.Name}}Calls() []struct {
func (mock *{{$mock.MockName}}
{{- if $mock.TypeParams -}}
[{{- range $index, $param := $mock.TypeParams}}
{{- if $index}}, {{end}}{{$param.Name | Exported}}
{{- end -}}]
{{- end -}}
) {{.Name}}Calls() []struct {
{{- range .Params}}
{{.Name | Exported}} {{.TypeString}}
{{- end}}
Expand Down
7 changes: 7 additions & 0 deletions internal/template/template_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package template

import (
"fmt"
"go/types"
"strings"

"github.com/matryer/moq/internal/registry"
Expand Down Expand Up @@ -33,6 +34,7 @@ func (d Data) MocksSomeMethod() bool {
type MockData struct {
InterfaceName string
MockName string
TypeParams []TypeParamData
Methods []MethodData
}

Expand Down Expand Up @@ -87,6 +89,11 @@ func (m MethodData) ReturnArgNameList() string {
return strings.Join(params, ", ")
}

type TypeParamData struct {
ParamData
Constraint types.Type
}

// ParamData is the data which represents a parameter to some method of
// an interface.
type ParamData struct {
Expand Down
41 changes: 40 additions & 1 deletion pkg/moq/moq.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package moq
import (
"bytes"
"errors"
"go/token"
"go/types"
"io"
"strings"
Expand Down Expand Up @@ -57,7 +58,7 @@ func (m *Mocker) Mock(w io.Writer, namePairs ...string) error {
mocks := make([]template.MockData, len(namePairs))
for i, np := range namePairs {
name, mockName := parseInterfaceName(np)
iface, err := m.registry.LookupInterface(name)
iface, tparams, err := m.registry.LookupInterface(name)
if err != nil {
return err
}
Expand All @@ -71,6 +72,7 @@ func (m *Mocker) Mock(w io.Writer, namePairs ...string) error {
InterfaceName: name,
MockName: mockName,
Methods: methods,
TypeParams: m.typeParams(tparams),
}
}

Expand Down Expand Up @@ -110,6 +112,43 @@ func (m *Mocker) Mock(w io.Writer, namePairs ...string) error {
return nil
}

func (m *Mocker) typeParams(tparams *types.TypeParamList) []template.TypeParamData {
var tpd []template.TypeParamData
if tparams == nil {
return tpd
}

tpd = make([]template.TypeParamData, tparams.Len())

scope := m.registry.MethodScope()
for i := 0; i < len(tpd); i++ {
tp := tparams.At(i)
typeParam := types.NewParam(token.Pos(i), tp.Obj().Pkg(), tp.Obj().Name(), tp.Constraint())
tpd[i] = template.TypeParamData{
ParamData: template.ParamData{Var: scope.AddVar(typeParam, "")},
Constraint: explicitConstraintType(typeParam),
}
}

return tpd
}

func explicitConstraintType(typeParam *types.Var) (t types.Type) {
underlying := typeParam.Type().Underlying().(*types.Interface)
// check if any of the embedded types is either a basic type or a union,
// because the generic type has to be an alias for one of those types then
for j := 0; j < underlying.NumEmbeddeds(); j++ {
t := underlying.EmbeddedType(j)
switch t := t.(type) {
case *types.Basic:
return t
case *types.Union: // only unions of basic types are allowed, so just take the first one as a valid type constraint
return t.Term(0).Type()
}
}
return nil
}

func (m *Mocker) methodData(f *types.Func) template.MethodData {
sig := f.Type().(*types.Signature)

Expand Down
6 changes: 6 additions & 0 deletions pkg/moq/moq_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,12 @@ func TestMockGolden(t *testing.T) {
interfaces: []string{"ShadowTypes"},
goldenFile: filepath.Join("testpackages/shadowtypes", "shadowtypes_moq.golden.go"),
},
{
name: "Generics",
cfg: Config{SrcDir: "testpackages/generics"},
interfaces: []string{"GenericStore1", "GenericStore2", "AliasStore"},
goldenFile: filepath.Join("testpackages/generics", "generics_moq.golden.go"),
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
Expand Down
32 changes: 32 additions & 0 deletions pkg/moq/testpackages/generics/generics.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package generics

import (
"context"
"fmt"
)

type GenericStore1[T Key1, S any] interface {
Get(ctx context.Context, id T) (S, error)
Create(ctx context.Context, id T, value S) error
}

type GenericStore2[T Key2, S any] interface {
Get(ctx context.Context, id T) (S, error)
Create(ctx context.Context, id T, value S) error
}

type AliasStore GenericStore1[KeyImpl, bool]

type Key1 interface {
fmt.Stringer
}

type Key2 interface {
~[]byte | string
}

type KeyImpl []byte

func (x KeyImpl) String() string {
return string(x)
}
Loading