Skip to content

Commit

Permalink
feat: detection type args that local or other package
Browse files Browse the repository at this point in the history
  • Loading branch information
mackee committed Jun 2, 2024
1 parent 8dde16f commit 615ad6c
Show file tree
Hide file tree
Showing 7 changed files with 141 additions and 27 deletions.
124 changes: 108 additions & 16 deletions _example/group.gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 6 additions & 4 deletions _example/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"time"

"github.com/go-sql-driver/mysql"
"github.com/mackee/go-sqlla/_example/id"
)

//go:generate go run ../cmd/sqlla/main.go
Expand All @@ -14,10 +15,11 @@ type GroupID uint64
//sqlla:table group
//genddl:table group
type Group struct {
ID GroupID `db:"id,primarykey,autoincrement"`
Name string `db:"name"`
LeaderUserID UserId `db:"leader_user_id"`
SubLeaderUserID sql.Null[UserId] `db:"sub_leader_user_id"`
ID GroupID `db:"id,primarykey,autoincrement"`
Name string `db:"name"`
LeaderUserID UserId `db:"leader_user_id"`
SubLeaderUserID sql.Null[UserId] `db:"sub_leader_user_id"`
ChildGroupID sql.Null[id.GroupID] `db:"child_group_id"`

CreatedAt time.Time `db:"created_at"`
UpdatedAt mysql.NullTime `db:"updated_at"`
Expand Down
3 changes: 3 additions & 0 deletions _example/id/id.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package id

type GroupID uint64
2 changes: 1 addition & 1 deletion generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func WriteCode(w io.Writer, table *Table) error {
i := 0
for scanner.Scan() {
i++
fmt.Printf("%05d: %s\n", i, scanner.Text())
// fmt.Printf("%05d: %s\n", i, scanner.Text())
}
bs, err := format.Source(buf.Bytes())
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module github.com/mackee/go-sqlla/v2
go 1.20

require (
github.com/Masterminds/goutils v1.1.1
github.com/go-sql-driver/mysql v1.6.0
github.com/pkg/errors v0.9.1
github.com/serenize/snaker v0.0.0-20201027110005-a7ad2135616e
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI=
github.com/Masterminds/goutils v1.1.1/go.mod h1:8cTjp+g8YejhMuvIA5y2vz3BpJxksy863GQaJW2MFNU=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
Expand Down
26 changes: 20 additions & 6 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,10 @@ func toTable(tablePkg *types.Package, annotationComment string, gd *ast.GenDecl,
table := new(Table)
table.Package = tablePkg
table.PackageName = tablePkg.Name()
table.additionalPackagesMap = make(map[string]struct{})

table.TableName = trimAnnotation(annotationComment)
qualifier := types.RelativeTo(tablePkg)

spec := gd.Specs[0]
ts, ok := spec.(*ast.TypeSpec)
Expand Down Expand Up @@ -140,14 +142,26 @@ func toTable(tablePkg *types.Package, annotationComment string, gd *ast.GenDecl,
}
baseTypeName = typeName
if _, ok := supportedGenericsTypes[typeName]; ok {
tps := nt.TypeParams()
if tps == nil {
tas := nt.TypeArgs()
if tas == nil {
return nil, fmt.Errorf("toTable: has not type params: table=%s, field=%s", table.TableName, columnName)
}
tpsStr := make([]string, tps.Len())
for i := 0; i < tps.Len(); i++ {
tp := tps.At(i)
tpsStr[i] = tp.String()
tpsStr := make([]string, tas.Len())
for i := 0; i < tas.Len(); i++ {
ta := tas.At(i)
switch ata := ta.(type) {
case *types.Named:
tn := ata.Obj()
tpsStr[i] = tn.Id()
if qualifier(tn.Pkg()) != "" {
tpsStr[i] = tn.Pkg().Name() + "." + tn.Id()
table.additionalPackagesMap[tn.Pkg().Path()] = struct{}{}
}
case *types.Basic:
tpsStr[i] = ata.Name()
default:
return nil, fmt.Errorf("toTable: unsupported type param: table=%s, field=%s, type=%s", table.TableName, columnName, ta.String())
}
}
typeParameter = strings.Join(tpsStr, ",")
} else if _, ok := supportedNonPrimitiveTypes[typeName]; !ok {
Expand Down

0 comments on commit 615ad6c

Please sign in to comment.