Skip to content

Commit

Permalink
Fix transform bug with result type oneof attribute (#3115)
Browse files Browse the repository at this point in the history
Upgrade to Go 1.19
Remove use of deprecated io/ioutil package
  • Loading branch information
raphael committed Aug 6, 2022
1 parent 42e10d7 commit bfc79af
Show file tree
Hide file tree
Showing 20 changed files with 167 additions and 104 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/ci.yml
Expand Up @@ -7,10 +7,10 @@ jobs:
runs-on: ubuntu-latest
steps:

- name: Set up Go 1.18
- name: Set up Go 1.19
uses: actions/setup-go@v3.2.1
with:
go-version: 1.18
go-version: 1.19
id: go

- name: Check out code into the Go module directory
Expand All @@ -34,10 +34,10 @@ jobs:
runs-on: windows-latest
steps:

- name: Set up Go 1.18
- name: Set up Go 1.19
uses: actions/setup-go@v3.2.1
with:
go-version: 1.18
go-version: 1.19
id: go

- name: Check out code into the Go module directory
Expand Down
5 changes: 2 additions & 3 deletions cmd/goa/gen.go
Expand Up @@ -6,7 +6,6 @@ import (
"go/build"
"go/parser"
"go/token"
"io/ioutil"
"os"
"os/exec"
"path/filepath"
Expand Down Expand Up @@ -67,7 +66,7 @@ func NewGenerator(cmd string, path, output string) *Generator {
}
}
for _, gof := range pkg.GoFiles {
if bs, err := ioutil.ReadFile(gof); err == nil {
if bs, err := os.ReadFile(gof); err == nil {
if f, err := parser.ParseFile(fset, "", string(bs), parser.ImportsOnly); err == nil {
for _, s := range f.Imports {
matches := p.FindStringSubmatch(s.Path.Value)
Expand Down Expand Up @@ -106,7 +105,7 @@ func (g *Generator) Write(debug bool) error {
if cwd, err := os.Getwd(); err != nil {
wd = cwd
}
tmp, err := ioutil.TempDir(wd, "goa")
tmp, err := os.MkdirTemp(wd, "goa")
if err != nil {
return err
}
Expand Down
7 changes: 3 additions & 4 deletions codegen/file.go
Expand Up @@ -9,7 +9,6 @@ import (
"go/scanner"
"go/token"
"io"
"io/ioutil"
"os"
"path/filepath"
"strings"
Expand Down Expand Up @@ -138,7 +137,7 @@ func finalizeGoSource(path string) error {
fset := token.NewFileSet()
file, err := parser.ParseFile(fset, path, nil, parser.ParseComments)
if err != nil {
content, _ := ioutil.ReadFile(path)
content, _ := os.ReadFile(path)
var buf bytes.Buffer
scanner.PrintError(&buf, err)
return fmt.Errorf("%s\n========\nContent:\n%s", buf.String(), content)
Expand Down Expand Up @@ -169,7 +168,7 @@ func finalizeGoSource(path string) error {
w.Close()

// Format code using goimport standard
bs, err := ioutil.ReadFile(path)
bs, err := os.ReadFile(path)
if err != nil {
return err
}
Expand All @@ -181,5 +180,5 @@ func finalizeGoSource(path string) error {
if err != nil {
return err
}
return ioutil.WriteFile(path, bs, os.ModePerm)
return os.WriteFile(path, bs, os.ModePerm)
}
3 changes: 1 addition & 2 deletions codegen/generator/generate.go
@@ -1,7 +1,6 @@
package generator

import (
"io/ioutil"
"os"
"path/filepath"
"sort"
Expand Down Expand Up @@ -36,7 +35,7 @@ func Generate(dir, cmd string) (outputs []string, err1 error) {
}

// We create a temporary Go file to make sure the directory is a valid Go package
dummy, err := ioutil.TempFile(path, "temp.*.go")
dummy, err := os.CreateTemp(path, "temp.*.go")
if err != nil {
return nil, err
}
Expand Down
9 changes: 6 additions & 3 deletions codegen/go_transform.go
Expand Up @@ -348,6 +348,10 @@ func transformMap(source, target *expr.Map, sourceVar, targetVar string, newVar
}

// transformUnion generates Go code to transform source union to target union.
//
// Note: transport to/from service transforms are always object to union or
// union to object. The only case a transform is union to union is when
// converting a projected type from/to a service type.
func transformUnion(source, target *expr.AttributeExpr, sourceVar, targetVar string, newVar bool, ta *TransformAttrs) (string, error) {
if expr.IsObject(target.Type) {
return transformUnionToObject(source, target, sourceVar, targetVar, newVar, ta)
Expand All @@ -365,7 +369,7 @@ func transformUnion(source, target *expr.AttributeExpr, sourceVar, targetVar str
}
sourceTypeRefs := make([]string, len(srcUnion.Values))
for i, st := range srcUnion.Values {
sourceTypeRefs[i] = ta.TargetCtx.Scope.Ref(st.Attribute, ta.TargetCtx.Pkg(st.Attribute))
sourceTypeRefs[i] = ta.TargetCtx.Scope.Ref(st.Attribute, ta.SourceCtx.Pkg(st.Attribute))
}
targetTypeNames := make([]string, len(tgtUnion.Values))
for i, tt := range tgtUnion.Values {
Expand Down Expand Up @@ -661,8 +665,7 @@ for key, val := range {{ .SourceVar }} {
{{ end }}switch actual := {{ .SourceVar }}.(type) {
{{- range $i, $ref := .SourceTypeRefs }}
case {{ $ref }}:
{{ transformAttribute (index $.SourceTypes $i).Attribute (index $.TargetTypes $i).Attribute "actual" "val" true $.TransformAttrs -}}
{{ $.TargetVar }} = {{ $.TargetTypeName }}{ Value: val }
{{- transformAttribute (index $.SourceTypes $i).Attribute (index $.TargetTypes $i).Attribute "actual" $.TargetVar false $.TransformAttrs -}}
{{- end }}
}
`
Expand Down
31 changes: 23 additions & 8 deletions codegen/go_transform_union_test.go
Expand Up @@ -13,12 +13,13 @@ func TestGoTransformUnion(t *testing.T) {
scope = NewNameScope()

// types to test
unionString = root.UserType("Container").Attribute().Find("UnionString").Find("UnionString")
unionString2 = root.UserType("Container").Attribute().Find("UnionString2").Find("UnionString2")
unionStringInt = root.UserType("Container").Attribute().Find("UnionStringInt").Find("UnionStringInt")
unionSomeType = root.UserType("Container").Attribute().Find("UnionSomeType").Find("UnionSomeType")
userType = &expr.AttributeExpr{Type: root.UserType("UnionUserType")}
defaultCtx = NewAttributeContext(false, false, true, "", scope)
unionString = root.UserType("Container").Attribute().Find("UnionString").Find("UnionString")
unionString2 = root.UserType("Container").Attribute().Find("UnionString2").Find("UnionString2")
unionStringInt = root.UserType("Container").Attribute().Find("UnionStringInt").Find("UnionStringInt")
unionStringInt2 = root.UserType("Container").Attribute().Find("UnionStringInt2").Find("UnionStringInt2")
unionSomeType = root.UserType("Container").Attribute().Find("UnionSomeType").Find("UnionSomeType")
userType = &expr.AttributeExpr{Type: root.UserType("UnionUserType")}
defaultCtx = NewAttributeContext(false, false, true, "", scope)
)
tc := []struct {
Name string
Expand All @@ -27,6 +28,7 @@ func TestGoTransformUnion(t *testing.T) {
Expected string
}{
{"UnionString to UnionString2", unionString, unionString2, unionToUnionCode},
{"UnionStringInt to UnionStringInt2", unionStringInt, unionStringInt2, unionMultiToUnionMultiCode},

{"UnionString to User Type", unionString, userType, unionStringToUserTypeCode},
{"UnionStringInt to User Type", unionStringInt, userType, unionStringIntToUserTypeCode},
Expand Down Expand Up @@ -89,8 +91,21 @@ const unionToUnionCode = `func transform() {
var target *UnionString2
switch actual := source.(type) {
case UnionStringString:
val := UnionString2String(actual)
target = UnionString2{Value: val}
target = UnionString2String(actual)
}
}
`

const unionMultiToUnionMultiCode = `func transform() {
var target *UnionStringInt2
switch actual := source.(type) {
case UnionStringIntString:
target = UnionStringInt2String(actual)
case UnionStringIntInt:
target = UnionStringInt2Int(actual)
}
}
`
Expand Down
104 changes: 66 additions & 38 deletions codegen/service/service_data.go
Expand Up @@ -78,6 +78,8 @@ type (
// projectedTypes lists the types which uses pointers for all fields to
// define view specific validation logic.
projectedTypes []*ProjectedTypeData
// union methods that need to be defined in views package.
viewedUnionMethods []*UnionValueMethodData
// viewedResultTypes lists all the viewed method result types.
viewedResultTypes []*ViewedResultTypeData
// unionValueMethods lists the methods used to define union types.
Expand Down Expand Up @@ -538,19 +540,20 @@ func (s SchemesData) Append(d *SchemeData) SchemesData {
// It records the user types needed by the service definition in userTypes.
func (d ServicesData) analyze(service *expr.ServiceExpr) *Data {
var (
scope *codegen.NameScope
viewScope *codegen.NameScope
pkgName string
viewspkg string
types []*UserTypeData
errTypes []*UserTypeData
errorInits []*ErrorInitData
projTypes []*ProjectedTypeData
viewedRTs []*ViewedResultTypeData
seenErrors map[string]struct{}
seen map[string]struct{}
seenProj map[string]*ProjectedTypeData
seenViewed map[string]*ViewedResultTypeData
scope *codegen.NameScope
viewScope *codegen.NameScope
pkgName string
viewspkg string
types []*UserTypeData
errTypes []*UserTypeData
errorInits []*ErrorInitData
projTypes []*ProjectedTypeData
viewedUnionMeths []*UnionValueMethodData
viewedRTs []*ViewedResultTypeData
seenErrors map[string]struct{}
seen map[string]struct{}
seenProj map[string]*ProjectedTypeData
seenViewed map[string]*ViewedResultTypeData
)
{
scope = codegen.NewNameScope()
Expand Down Expand Up @@ -593,7 +596,9 @@ func (d ServicesData) analyze(service *expr.ServiceExpr) *Data {
if _, ok := m.Result.Type.(*expr.ResultTypeExpr); ok {
// collect projected types for the corresponding result type
projected := expr.DupAtt(m.Result)
projTypes = append(projTypes, collectProjectedTypes(projected, m.Result, viewspkg, scope, viewScope, seenProj)...)
types, umeths := collectProjectedTypes(projected, m.Result, viewspkg, scope, viewScope, seenProj)
projTypes = append(projTypes, types...)
viewedUnionMeths = append(viewedUnionMeths, umeths...)
}
for _, er := range m.Errors {
recordError(er)
Expand Down Expand Up @@ -724,23 +729,24 @@ func (d ServicesData) analyze(service *expr.ServiceExpr) *Data {

varName := codegen.Goify(service.Name, false)
data := &Data{
Name: service.Name,
Description: desc,
VarName: varName,
PathName: codegen.SnakeCase(varName),
StructName: codegen.Goify(service.Name, true),
PkgName: pkgName,
ViewsPkg: viewspkg,
Methods: methods,
Schemes: schemes,
Scope: scope,
ViewScope: viewScope,
errorTypes: errTypes,
errorInits: errorInits,
userTypes: types,
projectedTypes: projTypes,
viewedResultTypes: viewedRTs,
unionValueMethods: ms,
Name: service.Name,
Description: desc,
VarName: varName,
PathName: codegen.SnakeCase(varName),
StructName: codegen.Goify(service.Name, true),
PkgName: pkgName,
ViewsPkg: viewspkg,
Methods: methods,
Schemes: schemes,
Scope: scope,
ViewScope: viewScope,
errorTypes: errTypes,
errorInits: errorInits,
userTypes: types,
projectedTypes: projTypes,
viewedUnionMethods: viewedUnionMeths,
viewedResultTypes: viewedRTs,
unionValueMethods: ms,
}
d[service.Name] = data

Expand Down Expand Up @@ -1149,8 +1155,8 @@ func BuildSchemeData(s *expr.SchemeExpr, m *expr.MethodExpr) *SchemeData {
// make use of views. We need to build projected types for all user types - not
// just result types - because user types make contain result types and thus may
// need to be marshalled in different ways depending on the view being used.
func collectProjectedTypes(projected, att *expr.AttributeExpr, viewspkg string, scope, viewScope *codegen.NameScope, seen map[string]*ProjectedTypeData) (data []*ProjectedTypeData) {
collect := func(projected, att *expr.AttributeExpr) []*ProjectedTypeData {
func collectProjectedTypes(projected, att *expr.AttributeExpr, viewspkg string, scope, viewScope *codegen.NameScope, seen map[string]*ProjectedTypeData) (data []*ProjectedTypeData, umeths []*UnionValueMethodData) {
collect := func(projected, att *expr.AttributeExpr) ([]*ProjectedTypeData, []*UnionValueMethodData) {
return collectProjectedTypes(projected, att, viewspkg, scope, viewScope, seen)
}
switch pt := projected.Type.(type) {
Expand All @@ -1171,22 +1177,44 @@ func collectProjectedTypes(projected, att *expr.AttributeExpr, viewspkg string,
pt.Rename(pt.Name() + "View")
// We recurse before building the projected type so that user types within
// a projected type is also converted to their respective projected types.
types := collect(pt.Attribute(), dt.Attribute())
types, ms := collect(pt.Attribute(), dt.Attribute())
pd := buildProjectedType(projected, att, viewspkg, scope, viewScope)
seen[dt.ID()] = pd
data = append(data, pd)
data = append(data, types...)
umeths = append(umeths, ms...)
case *expr.Array:
dt := att.Type.(*expr.Array)
data = append(data, collect(pt.ElemType, dt.ElemType)...)
types, ms := collect(pt.ElemType, dt.ElemType)
data = append(data, types...)
umeths = append(umeths, ms...)
case *expr.Map:
dt := att.Type.(*expr.Map)
data = append(data, collect(pt.KeyType, dt.KeyType)...)
data = append(data, collect(pt.ElemType, dt.ElemType)...)
types, ms := collect(pt.KeyType, dt.KeyType)
data = append(data, types...)
umeths = append(umeths, ms...)
types, ms = collect(pt.ElemType, dt.ElemType)
data = append(data, types...)
umeths = append(umeths, ms...)
case *expr.Object:
dt := att.Type.(*expr.Object)
for _, n := range *pt {
data = append(data, collect(n.Attribute, dt.Attribute(n.Name))...)
types, ms := collect(n.Attribute, dt.Attribute(n.Name))
data = append(data, types...)
umeths = append(umeths, ms...)
}
case *expr.Union:
dt := att.Type.(*expr.Union)
for i, n := range pt.Values {
types, ms := collect(n.Attribute, dt.Values[i].Attribute)
data = append(data, types...)
umeths = append(umeths, ms...)
}
for _, nat := range pt.Values {
umeths = append(umeths, &UnionValueMethodData{
Name: codegen.UnionValTypeName(pt.Name()),
TypeRef: scope.GoTypeRef(nat.Attribute),
})
}
}
return
Expand Down
10 changes: 10 additions & 0 deletions codegen/service/views.go
Expand Up @@ -49,6 +49,16 @@ func ViewsFile(genpkg string, service *expr.ServiceExpr) *codegen.File {
})
}

// Union methods
for _, m := range svc.viewedUnionMethods {
sections = append(sections, &codegen.SectionTemplate{
// addTypeDefSection(pathWithDefault(m.Loc, svcPath), "~"+m.TypeRef+"."+m.Name, &codegen.SectionTemplate{
Name: "viewed-union-value-method",
Source: unionValueMethodT,
Data: m,
})
}

// generate a map for result types with view name as key and the fields
// rendered in the view as value.
var (
Expand Down
7 changes: 7 additions & 0 deletions codegen/testdata/union_dsls.go
Expand Up @@ -26,6 +26,12 @@ var TestUnionDSL = func() {
Attribute("Int", Int)
})
})
UnionStringInt2 = Type("UnionStringInt2", func() {
OneOf("UnionStringInt2", func() {
Attribute("String", String)
Attribute("Int", Int)
})
})
UnionSomeType = Type("UnionSomeType", func() {
OneOf("UnionSomeType", func() {
Attribute("SomeType", SomeType)
Expand All @@ -36,6 +42,7 @@ var TestUnionDSL = func() {
Attribute("UnionString", UnionString)
Attribute("UnionString2", UnionString2)
Attribute("UnionStringInt", UnionStringInt)
Attribute("UnionStringInt2", UnionStringInt2)
Attribute("UnionSomeType", UnionSomeType)
})

Expand Down

0 comments on commit bfc79af

Please sign in to comment.