diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 79ce2d4..cdb7f2f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -28,10 +28,15 @@ jobs: cat $GITHUB_WORKSPACE/profile.cov_tmp | grep -v "_mock.go" > $GITHUB_WORKSPACE/profile.cov go build -race + - name: test examples + working-directory: _examples/status + run: go test -timeout=60s -race + - name: golangci-lint uses: golangci/golangci-lint-action@v3 with: version: latest + skip-pkg-cache: true - name: install goveralls run: | diff --git a/_examples/status/go.mod b/_examples/status/go.mod new file mode 100644 index 0000000..2004da0 --- /dev/null +++ b/_examples/status/go.mod @@ -0,0 +1,24 @@ +module examples/status + +go 1.24 + +require ( + github.com/stretchr/testify v1.10.0 + modernc.org/sqlite v1.35.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/ncruces/go-strftime v0.1.9 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0 // indirect + golang.org/x/sys v0.28.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect + modernc.org/libc v1.61.13 // indirect + modernc.org/mathutil v1.7.1 // indirect + modernc.org/memory v1.8.2 // indirect +) diff --git a/_examples/status/go.sum b/_examples/status/go.sum new file mode 100644 index 0000000..2135ea3 --- /dev/null +++ b/_examples/status/go.sum @@ -0,0 +1,57 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo= +github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= +github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0 h1:pVgRXcIictcr+lBQIFeiwuwtDIs4eL21OuM9nyAADmo= +golang.org/x/exp v0.0.0-20230315142452-642cacee5cc0/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= +golang.org/x/mod v0.19.0 h1:fEdghXQSo20giMthA7cd28ZC+jts4amQ3YMXiP5oMQ8= +golang.org/x/mod v0.19.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/tools v0.23.0 h1:SGsXPZ+2l4JsgaCKkx+FQ9YZ5XEtA1GZYuoDjenLjvg= +golang.org/x/tools v0.23.0/go.mod h1:pnu6ufv6vQkll6szChhK3C3L/ruaIv5eBeztNG8wtsI= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +modernc.org/cc/v4 v4.24.4 h1:TFkx1s6dCkQpd6dKurBNmpo+G8Zl4Sq/ztJ+2+DEsh0= +modernc.org/cc/v4 v4.24.4/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0= +modernc.org/ccgo/v4 v4.23.16 h1:Z2N+kk38b7SfySC1ZkpGLN2vthNJP1+ZzGZIlH7uBxo= +modernc.org/ccgo/v4 v4.23.16/go.mod h1:nNma8goMTY7aQZQNTyN9AIoJfxav4nvTnvKThAeMDdo= +modernc.org/fileutil v1.3.0 h1:gQ5SIzK3H9kdfai/5x41oQiKValumqNTDXMvKo62HvE= +modernc.org/fileutil v1.3.0/go.mod h1:XatxS8fZi3pS8/hKG2GH/ArUogfxjpEKs3Ku3aK4JyQ= +modernc.org/gc/v2 v2.6.3 h1:aJVhcqAte49LF+mGveZ5KPlsp4tdGdAOT4sipJXADjw= +modernc.org/gc/v2 v2.6.3/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= +modernc.org/libc v1.61.13 h1:3LRd6ZO1ezsFiX1y+bHd1ipyEHIJKvuprv0sLTBwLW8= +modernc.org/libc v1.61.13/go.mod h1:8F/uJWL/3nNil0Lgt1Dpz+GgkApWh04N3el3hxJcA6E= +modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= +modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= +modernc.org/memory v1.8.2 h1:cL9L4bcoAObu4NkxOlKWBWtNHIsnnACGF/TbqQ6sbcI= +modernc.org/memory v1.8.2/go.mod h1:ZbjSvMO5NQ1A2i3bWeDiVMxIorXwdClKE/0SZ+BMotU= +modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8= +modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= +modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= +modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= +modernc.org/sqlite v1.35.0 h1:yQps4fegMnZFdphtzlfQTCNBWtS0CZv48pRpW3RFHRw= +modernc.org/sqlite v1.35.0/go.mod h1:9cr2sicr7jIaWTBKQmAxQLfBv9LL0su4ZTEV+utt3ic= +modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= +modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= diff --git a/_examples/status/status_enum.go b/_examples/status/status_enum.go index fca192b..a7bf370 100644 --- a/_examples/status/status_enum.go +++ b/_examples/status/status_enum.go @@ -3,6 +3,7 @@ package status import ( + "database/sql/driver" "fmt" ) @@ -26,6 +27,36 @@ func (e *Status) UnmarshalText(text []byte) error { return err } +// Value implements the driver.Valuer interface +func (e Status) Value() (driver.Value, error) { + return e.name, nil +} + +// Scan implements the sql.Scanner interface +func (e *Status) Scan(value interface{}) error { + if value == nil { + *e = StatusValues()[0] + return nil + } + + str, ok := value.(string) + if !ok { + if b, ok := value.([]byte); ok { + str = string(b) + } else { + return fmt.Errorf("invalid status value: %v", value) + } + } + + val, err := ParseStatus(str) + if err != nil { + return err + } + + *e = val + return nil +} + // ParseStatus converts string to status enum value func ParseStatus(v string) (Status, error) { diff --git a/_examples/status/status_test.go b/_examples/status/status_test.go index 0ff6dba..a5f4370 100644 --- a/_examples/status/status_test.go +++ b/_examples/status/status_test.go @@ -1,15 +1,19 @@ package status import ( + "database/sql" "encoding/json" "fmt" "testing" + _ "modernc.org/sqlite" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestStatus(t *testing.T) { + t.Run("basic", func(t *testing.T) { s := StatusActive assert.Equal(t, "active", s.String()) @@ -31,6 +35,71 @@ func TestStatus(t *testing.T) { assert.Equal(t, StatusInactive, d2.Status) }) + t.Run("sql", func(t *testing.T) { + s := StatusActive + + // test Value() method + v, err := s.Value() + require.NoError(t, err) + assert.Equal(t, "active", v) + + // test Scan from string + var s2 Status + err = s2.Scan("inactive") + require.NoError(t, err) + assert.Equal(t, StatusInactive, s2) + + // test Scan from []byte + err = s2.Scan([]byte("blocked")) + require.NoError(t, err) + assert.Equal(t, StatusBlocked, s2) + + // test Scan from nil - should get first value from StatusValues() + err = s2.Scan(nil) + require.NoError(t, err) + assert.Equal(t, StatusValues()[0], s2) + + // test invalid value + err = s2.Scan(123) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid status value") + }) + + t.Run("sqlite", func(t *testing.T) { + db, err := sql.Open("sqlite", ":memory:") + require.NoError(t, err) + defer db.Close() + + // create table with status column + _, err = db.Exec(`CREATE TABLE test_status (id INTEGER PRIMARY KEY, status TEXT)`) + require.NoError(t, err) + + // insert different status values + statuses := []Status{StatusActive, StatusInactive, StatusBlocked} + for i, s := range statuses { + _, err = db.Exec(`INSERT INTO test_status (id, status) VALUES (?, ?)`, i+1, s) + require.NoError(t, err) + } + + // insert nil status + _, err = db.Exec(`INSERT INTO test_status (id, status) VALUES (?, ?)`, 4, nil) + require.NoError(t, err) + + // read and verify each status + for i, expected := range statuses { + var s Status + err = db.QueryRow(`SELECT status FROM test_status WHERE id = ?`, i+1).Scan(&s) + require.NoError(t, err) + assert.Equal(t, expected, s) + } + + // verify nil status gets first value + var s Status + err = db.QueryRow(`SELECT status FROM test_status WHERE id = 4`).Scan(&s) + require.NoError(t, err) + assert.Equal(t, StatusValues()[0], s) + }) + t.Run("invalid", func(t *testing.T) { var d struct { Status Status `json:"status"` diff --git a/go.work b/go.work new file mode 100644 index 0000000..072e85c --- /dev/null +++ b/go.work @@ -0,0 +1,6 @@ +go 1.24 + +use ( + . + ./_examples/status +) \ No newline at end of file diff --git a/go.work.sum b/go.work.sum new file mode 100644 index 0000000..b775ded --- /dev/null +++ b/go.work.sum @@ -0,0 +1,18 @@ +github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= +github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= +github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= +golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= +golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= +lukechampine.com/uint128 v1.2.0 h1:mBi/5l91vocEN8otkC5bDLhi2KdCticRiwbdB0O+rjI= +lukechampine.com/uint128 v1.2.0/go.mod h1:c4eWIwlEGaxC/+H1VguhU4PHXNWDCDMUlWdIWl2j1gk= +modernc.org/cc/v3 v3.41.0 h1:QoR1Sn3YWlmA1T4vLaKZfawdVtSiGx8H+cEojbC7v1Q= +modernc.org/cc/v3 v3.41.0/go.mod h1:Ni4zjJYJ04CDOhG7dn640WGfwBzfE0ecX8TyMB0Fv0Y= +modernc.org/ccgo/v3 v3.17.0 h1:o3OmOqx4/OFnl4Vm3G8Bgmqxnvxnh0nbxeT5p/dWChA= +modernc.org/ccgo/v3 v3.17.0/go.mod h1:Sg3fwVpmLvCUTaqEUjiBDAvshIaKDB0RXaf+zgqFu8I= diff --git a/internal/generator/generator.go b/internal/generator/generator.go index 1ba0b64..0db3b26 100644 --- a/internal/generator/generator.go +++ b/internal/generator/generator.go @@ -122,10 +122,11 @@ func (g *Generator) parseFile(file *ast.File) { } // Generate creates the enum code file. it takes the const values found in Parse and creates -// a new type with json, bson and text marshaling support. the generated code includes: +// a new type with json, sql and text marshaling support. the generated code includes: // - exported type with private name and value fields (e.g., Status{name: "active", value: 1}) // - string representation (String method) // - text marshaling (MarshalText/UnmarshalText methods) +// - sql marshaling (Value/Scan methods for driver.Valuer and sql.Scanner) // - parsing functions (Parse/Must variants) // - exported const values (e.g., StatusActive) // - helper functions to get all values and names @@ -246,6 +247,7 @@ var enumTemplate = template.Must(template.New("enum").Funcs(funcMap).Parse(`// C package {{.Package}} import ( + "database/sql/driver" "fmt" ) @@ -269,6 +271,36 @@ func (e *{{.Type | title}}) UnmarshalText(text []byte) error { return err } +// Value implements the driver.Valuer interface +func (e {{.Type | title}}) Value() (driver.Value, error) { + return e.name, nil +} + +// Scan implements the sql.Scanner interface +func (e *{{.Type | title}}) Scan(value interface{}) error { + if value == nil { + *e = {{.Type | title}}Values()[0] + return nil + } + + str, ok := value.(string) + if !ok { + if b, ok := value.([]byte); ok { + str = string(b) + } else { + return fmt.Errorf("invalid {{.Type}} value: %v", value) + } + } + + val, err := Parse{{.Type | title}}(str) + if err != nil { + return err + } + + *e = val + return nil +} + // Parse{{.Type | title}} converts string to {{.Type}} enum value func Parse{{.Type | title}}(v string) ({{.Type | title}}, error) { {{if .LowerCase}} diff --git a/internal/generator/generator_test.go b/internal/generator/generator_test.go index e9be765..c8c7d18 100644 --- a/internal/generator/generator_test.go +++ b/internal/generator/generator_test.go @@ -1,15 +1,19 @@ package generator import ( + "go/parser" + "go/token" "os" "path/filepath" "testing" + "text/template" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestGenerator(t *testing.T) { + t.Run("validation", func(t *testing.T) { _, err := New("", "") require.Error(t, err, "empty type name should fail") @@ -20,6 +24,52 @@ func TestGenerator(t *testing.T) { gen, err := New("status", "") require.NoError(t, err) assert.NotNil(t, gen) + + // check if generated code is valid Go code + tmpDir := t.TempDir() + gen, err = New("status", tmpDir) + require.NoError(t, err) + + err = gen.Parse("testdata") + require.NoError(t, err) + + err = gen.Generate() + require.NoError(t, err) + + // try to parse generated code + fset := token.NewFileSet() + genFile := filepath.Join(tmpDir, "status_enum.go") + _, err = parser.ParseFile(fset, genFile, nil, parser.AllErrors) + require.NoError(t, err, "generated code should be valid Go code") + + // validate default values correctness + content, err := os.ReadFile(genFile) + require.NoError(t, err) + + // check required imports + assert.Contains(t, string(content), `"database/sql/driver"`) + assert.Contains(t, string(content), `"fmt"`) + + // check required type definition + assert.Contains(t, string(content), "type Status struct {") + assert.Contains(t, string(content), "name string") + assert.Contains(t, string(content), "value int") + + // check all required methods are present + methods := []string{ + "String() string", + "MarshalText() ([]byte, error)", + "UnmarshalText(text []byte) error", + "Value() (driver.Value, error)", + "Scan(value interface{}) error", + "ParseStatus(v string) (Status, error)", + "MustStatus(v string) Status", + "StatusValues() []Status", + "StatusNames() []string", + } + for _, method := range methods { + assert.Contains(t, string(content), method, "method %s should be present", method) + } }) t.Run("parse and generate", func(t *testing.T) { @@ -48,6 +98,61 @@ func TestGenerator(t *testing.T) { assert.Contains(t, string(content), "StatusBlocked") }) + t.Run("sql support", func(t *testing.T) { + tmpDir := t.TempDir() + gen, err := New("status", tmpDir) + require.NoError(t, err) + + err = gen.Parse("testdata") + require.NoError(t, err) + + err = gen.Generate() + require.NoError(t, err) + + content, err := os.ReadFile(filepath.Join(tmpDir, "status_enum.go")) + require.NoError(t, err) + + // verify sql interface implementations are present + assert.Contains(t, string(content), "func (e Status) Value() (driver.Value, error)") + assert.Contains(t, string(content), "func (e *Status) Scan(value interface{}) error") + + // verify sql imports + assert.Contains(t, string(content), `"database/sql/driver"`) + + // verify nil handling + assert.Contains(t, string(content), "if value == nil {") + assert.Contains(t, string(content), "StatusValues()[0]") + + // verify []byte support + assert.Contains(t, string(content), "if b, ok := value.([]byte)") + }) + + t.Run("json support", func(t *testing.T) { + tmpDir := t.TempDir() + gen, err := New("status", tmpDir) + require.NoError(t, err) + + err = gen.Parse("testdata") + require.NoError(t, err) + + err = gen.Generate() + require.NoError(t, err) + + content, err := os.ReadFile(filepath.Join(tmpDir, "status_enum.go")) + require.NoError(t, err) + + // verify text marshaling interface implementations are present (used by json) + assert.Contains(t, string(content), "func (e Status) MarshalText() ([]byte, error)") + assert.Contains(t, string(content), "func (e *Status) UnmarshalText(text []byte) error") + + // verify proper error handling in unmarshal + assert.Contains(t, string(content), "invalid status value: %v") + assert.Contains(t, string(content), "ParseStatus(string(text))") + + // verify string conversion in marshal + assert.Contains(t, string(content), "return []byte(e.name), nil") + }) + t.Run("missing type", func(t *testing.T) { gen, err := New("nonexistent", "") require.NoError(t, err) @@ -55,6 +160,37 @@ func TestGenerator(t *testing.T) { err = gen.Parse("../testdata") assert.Error(t, err) }) + + t.Run("invalid package", func(t *testing.T) { + tmpDir := t.TempDir() + err := os.WriteFile(filepath.Join(tmpDir, "invalid.go"), []byte(`invalid go file`), 0o600) + require.NoError(t, err) + + gen, err := New("status", tmpDir) + require.NoError(t, err) + + err = gen.Parse(tmpDir) + assert.Error(t, err) + }) + + t.Run("non-existent directory", func(t *testing.T) { + gen, err := New("status", "") + require.NoError(t, err) + + err = gen.Parse("non-existent-dir") + assert.Error(t, err) + }) + + t.Run("invalid output directory", func(t *testing.T) { + gen, err := New("status", "/non-existent-dir") + require.NoError(t, err) + + err = gen.Parse("testdata") + require.NoError(t, err) + + err = gen.Generate() + assert.Error(t, err) + }) } func TestGeneratorValues(t *testing.T) { @@ -147,3 +283,95 @@ func TestGeneratorLowerCase(t *testing.T) { assert.Contains(t, string(content), "strings.ToLower") }) } + +func TestGeneratorEdgeCases(t *testing.T) { + t.Run("invalid template", func(t *testing.T) { + // Create a generator with a broken template that will fail to execute + gen, err := New("status", "") + require.NoError(t, err) + + // Override template with invalid one + origTmpl := enumTemplate + defer func() { enumTemplate = origTmpl }() + enumTemplate = template.Must(template.New("broken").Parse("{{.Unknown}}")) // will fail on execution + + err = gen.Parse("testdata") + require.NoError(t, err) + + err = gen.Generate() + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to execute template") + }) + + t.Run("format error", func(t *testing.T) { + gen, err := New("status", "") + require.NoError(t, err) + + // Override template to generate invalid Go code + origTmpl := enumTemplate + defer func() { enumTemplate = origTmpl }() + enumTemplate = template.Must(template.New("invalid").Parse("invalid go code")) + + err = gen.Parse("testdata") + require.NoError(t, err) + + err = gen.Generate() + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to format source") + }) + + t.Run("invalid identifier", func(t *testing.T) { + tests := []struct { + name string + input string + expected bool + }{ + {"empty", "", false}, + {"starts with number", "123abc", false}, + {"valid", "abc123", true}, + {"valid with underscore", "abc_123", true}, + {"starts with underscore", "_abc123", true}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expected, isValidGoIdentifier(tc.input)) + }) + } + }) +} + +func TestParseSpecialCases(t *testing.T) { + t.Run("empty const block", func(t *testing.T) { + tmpDir := t.TempDir() + err := os.WriteFile(filepath.Join(tmpDir, "empty.go"), []byte(` +package test +const ( +) +`), 0o644) + require.NoError(t, err) + + gen, err := New("status", "") + require.NoError(t, err) + + err = gen.Parse(tmpDir) + require.Error(t, err) + assert.Contains(t, err.Error(), "no const values found for type status") + }) + + t.Run("const without values", func(t *testing.T) { + tmpDir := t.TempDir() + err := os.WriteFile(filepath.Join(tmpDir, "no_values.go"), []byte(` +package test +const name string +`), 0o644) + require.NoError(t, err) + + gen, err := New("status", "") + require.NoError(t, err) + + err = gen.Parse(tmpDir) + require.Error(t, err) + assert.Contains(t, err.Error(), "no const values found for type status") + }) +} diff --git a/main.go b/main.go index d60054a..db66f39 100644 --- a/main.go +++ b/main.go @@ -1,4 +1,4 @@ -// Package main provides command line tool to generate enum code from the type definition. +// Package main provides entry point for enum generator package main import ( @@ -11,6 +11,9 @@ import ( var version = "dev" +// allow mocking os.Exit in tests +var osExit = os.Exit + func main() { typeFlag := flag.String("type", "", "type name (must be lowercase)") pathFlag := flag.String("path", "", "output directory path (default: same as source)") @@ -21,30 +24,35 @@ func main() { if *helpFlag { showUsage() - os.Exit(0) + osExit(0) + return } if *versionFlag { fmt.Printf("enum generator %s\n", version) - os.Exit(0) + osExit(0) + return } gen, err := generator.New(*typeFlag, *pathFlag) if err != nil { fmt.Printf("%v\n", err) showUsage() - os.Exit(1) + osExit(1) + return } gen.SetLowerCase(*lowerFlag) if err := gen.Parse("."); err != nil { fmt.Printf("%v\n", err) - os.Exit(1) + osExit(1) + return } if err := gen.Generate(); err != nil { fmt.Printf("%v\n", err) - os.Exit(1) + osExit(1) + return } } diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..015d843 --- /dev/null +++ b/main_test.go @@ -0,0 +1,169 @@ +package main + +import ( + "flag" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// save original os.Exit and restore after test +var origExit = osExit + +func TestMain(m *testing.M) { + // remove all mocks after all tests + defer func() { osExit = origExit }() + m.Run() +} + +func TestIntegration(t *testing.T) { + // Reset flags between runs to avoid "flag redefined" error + flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError) + + t.Run("generate enum", func(t *testing.T) { + // save original args and restore after test + origArgs := os.Args + origWd, err := os.Getwd() + require.NoError(t, err) + defer func() { + os.Args = origArgs + require.NoError(t, os.Chdir(origWd)) + }() + + tmpDir := t.TempDir() + + // copy testdata to tmp + err = os.WriteFile(filepath.Join(tmpDir, "status.go"), []byte(` +package test +type status uint8 +const ( + statusUnknown status = iota + statusActive + statusInactive +) +`), 0o644) + require.NoError(t, err) + + // change working directory to temp dir + require.NoError(t, os.Chdir(tmpDir)) + + // no exit should happen here + var exitCode int + osExit = func(code int) { exitCode = code } + + // set args and run main + os.Args = []string{"app", "-type", "status"} + main() + + assert.Equal(t, 0, exitCode, "unexpected os.Exit call") + + // verify generated file + content, err := os.ReadFile(filepath.Join(tmpDir, "status_enum.go")) + require.NoError(t, err) + assert.Contains(t, string(content), "type Status struct") + assert.Contains(t, string(content), "StatusActive") + assert.Contains(t, string(content), "StatusInactive") + }) + + t.Run("lower case", func(t *testing.T) { + // Reset flags for this run + flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError) + + origArgs := os.Args + origWd, err := os.Getwd() + require.NoError(t, err) + defer func() { + os.Args = origArgs + require.NoError(t, os.Chdir(origWd)) + }() + + tmpDir := t.TempDir() + err = os.WriteFile(filepath.Join(tmpDir, "status.go"), []byte(` +package test +type status uint8 +const ( + statusUnknown status = iota + statusActive +) +`), 0o644) + require.NoError(t, err) + + // change working directory to temp dir + require.NoError(t, os.Chdir(tmpDir)) + + var exitCode int + osExit = func(code int) { exitCode = code } + + os.Args = []string{"app", "-type", "status", "-lower"} + main() + + assert.Equal(t, 0, exitCode, "unexpected os.Exit call") + + content, err := os.ReadFile(filepath.Join(tmpDir, "status_enum.go")) + require.NoError(t, err) + assert.Contains(t, string(content), `name: "active"`) + }) + + t.Run("version", func(t *testing.T) { + // Reset flags for this run + flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError) + + origArgs := os.Args + defer func() { os.Args = origArgs }() + + var exitCode int + osExit = func(code int) { exitCode = code } + + os.Args = []string{"app", "-version"} + main() + assert.Equal(t, 0, exitCode) + }) + + t.Run("help", func(t *testing.T) { + // Reset flags for this run + flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError) + + origArgs := os.Args + defer func() { os.Args = origArgs }() + + var exitCode int + osExit = func(code int) { exitCode = code } + + os.Args = []string{"app", "-help"} + main() + assert.Equal(t, 0, exitCode) + }) + + t.Run("missing type", func(t *testing.T) { + // Reset flags for this run + flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError) + + origArgs := os.Args + defer func() { os.Args = origArgs }() + + var exitCode int + osExit = func(code int) { exitCode = code } + + os.Args = []string{"app"} + main() + assert.Equal(t, 1, exitCode) + }) + + t.Run("uppercase type", func(t *testing.T) { + // Reset flags for this run + flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError) + + origArgs := os.Args + defer func() { os.Args = origArgs }() + + var exitCode int + osExit = func(code int) { exitCode = code } + + os.Args = []string{"app", "-type", "Status"} + main() + assert.Equal(t, 1, exitCode) + }) +}