diff --git a/Makefile b/Makefile index 00c855a..378d046 100644 --- a/Makefile +++ b/Makefile @@ -33,7 +33,7 @@ lint: ./bin/golangci-lint .PHONY: test: - go test -race ./... + go test -race -v ./... .PHONY: generate: ./bin/gowrap ./bin/minimock diff --git a/generator/generics.go b/generator/generics.go index 780956a..7d22c27 100644 --- a/generator/generics.go +++ b/generator/generics.go @@ -2,6 +2,7 @@ package generator import ( "go/ast" + "go/token" "strings" ) @@ -78,42 +79,75 @@ func (g genericTypes) buildVars() (string, string) { } func buildGenericTypesFromSpec(ts *ast.TypeSpec, allTypes []*ast.TypeSpec, typesPrefix string) (types genericTypes) { - if ts.TypeParams != nil { - for _, param := range ts.TypeParams.List { - if param != nil { - var typeIdentifier string - switch t := param.Type.(type) { - case *ast.Ident: - prefix := "" - if typesPrefix != "" { - for _, at := range allTypes { - if at.Name.Name == t.Name { - prefix = typesPrefix + "." - break - } - } - } - - typeIdentifier = prefix + t.Name - case *ast.SelectorExpr: - typeIdentifier = t.X.(*ast.Ident).Name + "." + t.Sel.Name - default: - panic("unsupported generic type") - } + if ts.TypeParams == nil { + return + } - var paramNames []string - for _, name := range param.Names { - if name != nil { - paramNames = append(paramNames, name.Name) - } - } - types = append(types, genericType{ - Type: typeIdentifier, - Names: paramNames, - }) + for _, param := range ts.TypeParams.List { + if param == nil { + continue + } + + typeIdentifier := parseGenericType(param.Type, allTypes, typesPrefix) + + var paramNames []string + for _, name := range param.Names { + if name != nil { + paramNames = append(paramNames, name.Name) } } + types = append(types, genericType{ + Type: typeIdentifier, + Names: paramNames, + }) + } + + return +} + +func parseBinaryExpr(expr *ast.BinaryExpr, allTypes []*ast.TypeSpec, typesPrefix string) string { + if expr.Op != token.OR { + return "" } + + leftPart := parseGenericType(expr.X, allTypes, typesPrefix) + rightPart := parseGenericType(expr.Y, allTypes, typesPrefix) + + return leftPart + " | " + rightPart +} + +func parseGenericType(exprPart ast.Expr, allTypes []*ast.TypeSpec, typesPrefix string) (part string) { + switch expr := exprPart.(type) { + case *ast.Ident: + prefix := getPrefix(expr, allTypes, typesPrefix) + part = prefix + expr.Name + case *ast.SelectorExpr: + part = expr.X.(*ast.Ident).Name + "." + expr.Sel.Name + case *ast.UnaryExpr: + if expr.Op == token.TILDE { + if id, ok := expr.X.(*ast.Ident); ok { + part = "~" + id.Name + } + } + case *ast.BinaryExpr: + part = parseBinaryExpr(expr, allTypes, typesPrefix) + } + + return +} + +func getPrefix(t *ast.Ident, allTypes []*ast.TypeSpec, typesPrefix string) (prefix string) { + if typesPrefix == "" { + return + } + + for _, at := range allTypes { + if at.Name.Name == t.Name { + prefix = typesPrefix + "." + break + } + } + return } diff --git a/generator/generics_test.go b/generator/generics_test.go index d985179..a688fb4 100644 --- a/generator/generics_test.go +++ b/generator/generics_test.go @@ -2,6 +2,7 @@ package generator import ( "go/ast" + "go/token" "reflect" "testing" ) @@ -228,6 +229,140 @@ func Test_buildGenericTypesFromSpec(t *testing.T) { }, }, }, + { + name: "build generic types with type approximation", + args: args{ + ts: &ast.TypeSpec{ + TypeParams: &ast.FieldList{ + List: []*ast.Field{ + { + Type: &ast.UnaryExpr{ + Op: token.TILDE, + X: &ast.Ident{ + Name: "float64", + }, + }, + Names: []*ast.Ident{ + { + Name: "T", + }, + }, + }, + }, + }, + }, + allTypes: []*ast.TypeSpec{ + { + Name: &ast.Ident{ + Name: "Bar", + }, + }, + }, + typesPrefix: "prefix", + }, + wantTypes: genericTypes{ + { + Type: "~float64", + Names: []string{"T"}, + }, + }, + }, + { + name: "build generic types from union", + args: args{ + ts: &ast.TypeSpec{ + TypeParams: &ast.FieldList{ + List: []*ast.Field{ + { + Type: &ast.BinaryExpr{ + X: &ast.BinaryExpr{ + X: &ast.Ident{ + Name: "int", + }, + Op: token.OR, + Y: &ast.SelectorExpr{ + X: &ast.Ident{ + Name: "pkg", + }, + Sel: &ast.Ident{ + Name: "Baz", + }, + }, + }, + Op: token.OR, + Y: &ast.Ident{ + Name: "Bar", + }, + }, + Names: []*ast.Ident{ + { + Name: "B", + }, + }, + }, + }, + }, + }, + allTypes: []*ast.TypeSpec{ + { + Name: &ast.Ident{ + Name: "Bar", + }, + }, + }, + typesPrefix: "prefix", + }, + wantTypes: genericTypes{ + { + Type: "int | pkg.Baz | prefix.Bar", + Names: []string{"B"}, + }, + }, + }, + { + name: "build generic types from union with type approximation", + args: args{ + ts: &ast.TypeSpec{ + TypeParams: &ast.FieldList{ + List: []*ast.Field{ + { + Type: &ast.BinaryExpr{ + X: &ast.Ident{ + Name: "Bar", + }, + Op: token.OR, + Y: &ast.UnaryExpr{ + X: &ast.Ident{ + Name: "float32", + }, + Op: token.TILDE, + }, + }, + Names: []*ast.Ident{ + { + Name: "T", + }, + }, + }, + }, + }, + }, + allTypes: []*ast.TypeSpec{ + { + Name: &ast.Ident{ + Name: "Bar", + }, + }, + }, + typesPrefix: "prefix", + }, + wantTypes: genericTypes{ + { + Type: "prefix.Bar | ~float32", + Names: []string{"T"}, + }, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {