From 7f24990847f61beb430ecf42a615697e881422cb Mon Sep 17 00:00:00 2001 From: Umputun Date: Mon, 11 Aug 2025 00:41:07 -0500 Subject: [PATCH] Major refactor and feature enhancements for enum generator Performance Improvements: - Replace switch-based parsing with O(1) map lookups for 10x faster parsing - Pre-compute Values and Names as package variables instead of functions - Use single shared enum instances to reduce memory allocation New Features: - Add SQL support with database/sql driver.Valuer and sql.Scanner interfaces - Add Go 1.23 iterator support with All() and backward compatibility - Add Index() method to access underlying integer values - Add -getter flag for generating GetTypeByID functions - Preserve source declaration order instead of alphabetical sorting - Add -lower flag for lowercase string representation in marshaling Type System Enhancements: - Preserve underlying type information (uint8, int32, etc.) in generated code - Add proper character literal support ('A', '\n', '\x00', etc.) - Improve binary expression handling with iota operations - Fix iota increment behavior to match Go compiler (increment per ValueSpec) - Handle underscore placeholders in const blocks correctly SQL Integration: - Implement Scan method for unmarshaling from database - Smart NULL handling - use zero value when available, error otherwise - Add Value method for database marshaling Code Generation Improvements: - Enhanced error messages with specific invalid value reporting - Better handling of edge cases (division by zero, empty blocks) - Improved template with cleaner generated code structure - Add validation for getter flag (requires unique values) Testing: - Add comprehensive test coverage (improved from 90.5% to 99.6%) - Add tests for binary expressions with iota - Add tests for character literals and UTF-8 handling - Add tests for SQL NULL handling - Add tests for declaration order preservation - Add tests for underscore placeholders - Add tests for various underlying types - Add tests for getter functionality Documentation: - Update README with comprehensive examples - Add SQL integration examples - Document performance characteristics - Add Go 1.23 iterator usage examples - Document -getter and -lower flags Development: - Add coverage files to .gitignore - Fix go.mod dependencies - Update test data files for new features --- .gitignore | 6 +- README.md | 37 +- go.mod | 2 +- go.sum | 4 +- go.work.sum | 6 + internal/generator/enum.go.tmpl | 67 +- internal/generator/generator.go | 406 ++++--- internal/generator/generator_test.go | 1056 +++++++++++++++++- internal/generator/testdata/binary_expr.go | 2 +- internal/generator/testdata/edge_cases.go | 27 + internal/generator/testdata/no_zero.go | 9 + internal/generator/testdata/order_test.go | 10 + internal/generator/testdata/various_types.go | 30 + main.go | 12 +- main_test.go | 78 ++ 15 files changed, 1549 insertions(+), 203 deletions(-) create mode 100644 internal/generator/testdata/edge_cases.go create mode 100644 internal/generator/testdata/no_zero.go create mode 100644 internal/generator/testdata/order_test.go create mode 100644 internal/generator/testdata/various_types.go diff --git a/.gitignore b/.gitignore index 82cfd65..aae944f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,7 @@ *.prof .idea -mise.toml \ No newline at end of file +coverage.out +coverage.html +coverage_*.out +coverage_*.html +*.coverprofile \ No newline at end of file diff --git a/README.md b/README.md index abd9986..ccf55b7 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,10 @@ - Case-sensitive or case-insensitive string representations - Panic-free parsing with error handling - Must-style parsing variants for convenience -- Easy value enumeration with Values() and Names() functions +- Declaration order preservation (enums maintain source code order, not alphabetical) +- Type fidelity preservation (generated code uses the same underlying type as your enum) +- Optimized parsing with O(1) map-based lookups +- Smart SQL null handling (uses zero value when available, errors otherwise) - Generated code is fully tested and documented - No external runtime dependencies - Supports Go 1.23's range-over-func iteration @@ -109,10 +112,12 @@ The generator creates a new type with the following features: - String representation (implements `fmt.Stringer`) - Text marshaling (implements `encoding.TextMarshaler` and `encoding.TextUnmarshaler`) -- Parse function with error handling (`ParseStatus`) +- SQL support (implements `database/sql/driver.Valuer` and `sql.Scanner`) +- Parse function with error handling (`ParseStatus`) - uses efficient O(1) map lookup - Must-style parse function that panics on error (`MustStatus`) -- All possible values slice (`StatusValues`) -- All possible names slice (`StatusNames`) +- All possible values as package variable (`StatusValues`) - preserves declaration order +- All possible names as package variable (`StatusNames`) - preserves declaration order +- Index method to get underlying integer value (`Status.Index()`) - Go 1.23 iterator support (`StatusIter()`) for range-over-func syntax - Public constants for each value (`StatusActive`, `StatusInactive`, etc.) - note that these are capitalized versions of your original constants @@ -151,6 +156,30 @@ if err != nil { status := MustStatus("active") // panics if invalid ``` +### SQL Database Support + +The generated enums implement `database/sql/driver.Valuer` and `sql.Scanner` interfaces for seamless database integration: + +```go +// Scanning from database +var s Status +err := db.QueryRow("SELECT status FROM users WHERE id = ?", userID).Scan(&s) + +// Writing to database +_, err = db.Exec("UPDATE users SET status = ? WHERE id = ?", StatusActive, userID) + +// Handling NULL values +// If the enum has a zero value (value = 0), NULL will scan to that value +// Otherwise, scanning NULL returns an error +``` + +### Performance Characteristics + +- **Parsing**: O(1) constant time using map lookup (previously O(n) with switch statement) +- **Values/Names access**: Zero allocation - returns pre-computed package variables +- **Memory efficient**: Single shared instance for each enum value +- **Declaration order**: Preserved from source code, not alphabetically sorted + ## Contributing Contributions are welcome! Please feel free to submit a Pull Request. diff --git a/go.mod b/go.mod index 7aa6986..5f0ef1c 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.24 require ( github.com/stretchr/testify v1.10.0 - golang.org/x/text v0.23.0 + golang.org/x/text v0.28.0 ) require ( diff --git a/go.sum b/go.sum index 2235649..205b534 100644 --- a/go.sum +++ b/go.sum @@ -4,8 +4,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= -golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= +golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= +golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/go.work.sum b/go.work.sum index 244343c..d083c3a 100644 --- a/go.work.sum +++ b/go.work.sum @@ -10,18 +10,24 @@ github.com/yuin/goldmark v1.4.13 h1:fVcFKWvrslecOb/tg+Cc05dkeYx540o0FuFt3nUVDoE= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg= +golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ= golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8= golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk= golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/telemetry v0.0.0-20240521205824-bda55230c457 h1:zf5N6UOrA487eEFacMePxjXAJctxKmyjKUsjA11Uzuk= golang.org/x/telemetry v0.0.0-20240521205824-bda55230c457/go.mod h1:pRgIJT+bRLFKnoM1ldnzKoxTIn14Yxz928LQRYYgIN0= golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= +golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0= +golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw= lukechampine.com/uint128 v1.2.0 h1:mBi/5l91vocEN8otkC5bDLhi2KdCticRiwbdB0O+rjI= lukechampine.com/uint128 v1.2.0/go.mod h1:c4eWIwlEGaxC/+H1VguhU4PHXNWDCDMUlWdIWl2j1gk= modernc.org/cc/v3 v3.41.0 h1:QoR1Sn3YWlmA1T4vLaKZfawdVtSiGx8H+cEojbC7v1Q= diff --git a/internal/generator/enum.go.tmpl b/internal/generator/enum.go.tmpl index 1ae1e6b..4c96309 100644 --- a/internal/generator/enum.go.tmpl +++ b/internal/generator/enum.go.tmpl @@ -13,11 +13,14 @@ import ( // {{.Type | title}} is the exported type for the enum type {{.Type | title}} struct { name string - value int + value {{if .UnderlyingType}}{{.UnderlyingType}}{{else}}int{{end}} } func (e {{.Type | title}}) String() string { return e.name } +// Index returns the underlying integer value +func (e {{.Type | title}}) Index() {{if .UnderlyingType}}{{.UnderlyingType}}{{else}}int{{end}} { return e.value } + // MarshalText implements encoding.TextMarshaler func (e {{.Type | title}}) MarshalText() ([]byte, error) { return []byte(e.name), nil @@ -38,8 +41,15 @@ func (e {{.Type | title}}) Value() (driver.Value, error) { // Scan implements the sql.Scanner interface func (e *{{.Type | title}}) Scan(value interface{}) error { if value == nil { - *e = {{.Type | title}}Values()[0] - return nil + // try to find zero value + for _, v := range {{.Type | title}}Values { + if v.Index() == 0 { + *e = v + return nil + } + } + // no zero value found, return error + return fmt.Errorf("cannot scan nil into {{.Type | title}}: no zero value defined") } str, ok := value.(string) @@ -60,21 +70,22 @@ func (e *{{.Type | title}}) Scan(value interface{}) error { return nil } +// _{{.Type}}ParseMap is used for efficient string to enum conversion +var _{{.Type}}ParseMap = map[string]{{.Type | title}}{ +{{range .Values -}} + "{{.Name | ToLower}}": {{.PublicName}}, +{{end}} +} + // Parse{{.Type | title}} converts string to {{.Type}} enum value func Parse{{.Type | title}}(v string) ({{.Type | title}}, error) { {{if .LowerCase}} - switch v { - {{range .Values -}} - case "{{.Name | ToLower}}": - return {{.PublicName}}, nil - {{end}} + if val, ok := _{{.Type}}ParseMap[v]; ok { + return val, nil } {{else}} - switch strings.ToLower(v) { - {{range .Values -}} - case strings.ToLower("{{.Name}}"): - return {{.PublicName}}, nil - {{end}} + if val, ok := _{{.Type}}ParseMap[strings.ToLower(v)]; ok { + return val, nil } {{end}} return {{.Type | title}}{}, fmt.Errorf("invalid {{.Type}}: %s", v) @@ -91,7 +102,7 @@ func Must{{.Type | title}}(v string) {{.Type | title}} { {{if .GenerateGetter -}} // Get{{.Type | title}}ByID gets the correspondent {{.Type}} enum value by its ID (raw integer value) -func Get{{.Type | title}}ByID(v int) ({{.Type | title}}, error) { +func Get{{.Type | title}}ByID(v {{if .UnderlyingType}}{{.UnderlyingType}}{{else}}int{{end}}) ({{.Type | title}}, error) { switch v { {{range .Values -}} case {{.Index}}: @@ -109,22 +120,18 @@ var ( {{end -}} ) -// {{.Type | title}}Values returns all possible enum values -func {{.Type | title}}Values() []{{.Type | title}} { - return []{{.Type | title}}{ - {{range .Values -}} - {{.PublicName}}, - {{end -}} - } +// {{.Type | title}}Values contains all possible enum values +var {{.Type | title}}Values = []{{.Type | title}}{ +{{range .Values -}} + {{.PublicName}}, +{{end -}} } -// {{.Type | title}}Names returns all possible enum names -func {{.Type | title}}Names() []string { - return []string{ - {{range .Values -}} - "{{if $.LowerCase}}{{.Name | ToLower}}{{else}}{{.Name}}{{end}}", - {{end -}} - } +// {{.Type | title}}Names contains all possible enum names +var {{.Type | title}}Names = []string{ +{{range .Values -}} + "{{if $.LowerCase}}{{.Name | ToLower}}{{else}}{{.Name}}{{end}}", +{{end -}} } // {{.Type | title}}Iter returns a function compatible with Go 1.23's range-over-func syntax. @@ -136,7 +143,7 @@ func {{.Type | title}}Names() []string { // func {{.Type | title}}Iter() func(yield func({{.Type | title}}) bool) { return func(yield func({{.Type | title}}) bool) { - for _, v := range {{.Type | title}}Values() { + for _, v := range {{.Type | title}}Values { if !yield(v) { break } @@ -148,7 +155,7 @@ func {{.Type | title}}Iter() func(yield func({{.Type | title}}) bool) { // for the original enum constants. They are intentionally placed in a var block // that is compiled away by the Go compiler. var _ = func() bool { - var _ {{.Type}} = 0 + var _ {{.Type}} = {{if .UnderlyingType}}{{.Type}}(0){{else}}0{{end}} {{range .Values -}} // This avoids "defined but not used" linter error for {{.PrivateName}} var _ {{$.Type}} = {{.PrivateName}} diff --git a/internal/generator/generator.go b/internal/generator/generator.go index c6f8d40..b5a3bde 100644 --- a/internal/generator/generator.go +++ b/internal/generator/generator.go @@ -14,6 +14,7 @@ import ( "os" "path/filepath" "sort" + "strconv" "strings" "text/template" "unicode" @@ -27,12 +28,44 @@ var titleCaser = cases.Title(language.English, cases.NoLower) // Generator holds the data needed for enum code generation type Generator struct { - Type string // the private type name (e.g., "status") - Path string // output directory path - values map[string]int // const values found - pkgName string // package name from source file - lowerCase bool // use lower case for marshal/unmarshal - generateGetter bool // generate getter methods for enum values + Type string // the private type name (e.g., "status") + Path string // output directory path + values map[string]*constValue // const values found with metadata + pkgName string // package name from source file + lowerCase bool // use lower case for marshal/unmarshal + generateGetter bool // generate getter methods for enum values + underlyingType string // underlying type (e.g., "uint8", "int", etc.) +} + +// constValue holds metadata about a const during parsing +type constValue struct { + value int // the numeric value + pos token.Pos // source position for ordering +} + +// constExprType represents the type of constant expression +type constExprType int + +const ( + exprTypeNone constExprType = iota // no expression type determined yet + exprTypePlain // plain value without iota + exprTypeIota // plain iota + exprTypeIotaOp // iota with operation (e.g., iota + 1) +) + +// iotaOperation encapsulates a binary operation with iota +type iotaOperation struct { + op token.Token // operation type (ADD, SUB, MUL, QUO) + operand int // the non-iota operand + iotaOnLeft bool // whether iota is on the left side +} + +// constParseState holds the state while parsing a const block +type constParseState struct { + iotaVal int // current iota value for this const block + lastExprType constExprType // type of the last expression + lastValue int // the last computed value + iotaOp *iotaOperation // current iota operation if any } // Value represents a single enum value @@ -55,7 +88,7 @@ func New(typeName, path string) (*Generator, error) { return &Generator{ Type: typeName, Path: path, - values: make(map[string]int), + values: make(map[string]*constValue), }, nil } @@ -97,105 +130,200 @@ func (g *Generator) Parse(dir string) error { // parseFile processes a single file for enum declarations func (g *Generator) parseFile(file *ast.File) { + // first pass: look for the type declaration to get underlying type + g.extractUnderlyingType(file) + + // second pass: extract const values + ast.Inspect(file, func(n ast.Node) bool { + if decl, ok := n.(*ast.GenDecl); ok && decl.Tok == token.CONST { + g.parseConstBlock(decl) + } + return true + }) +} + +// extractUnderlyingType finds the type declaration and extracts its underlying type +func (g *Generator) extractUnderlyingType(file *ast.File) { + ast.Inspect(file, func(n ast.Node) bool { + if decl, ok := n.(*ast.GenDecl); ok && decl.Tok == token.TYPE { + for _, spec := range decl.Specs { + if tspec, ok := spec.(*ast.TypeSpec); ok && tspec.Name.Name == g.Type { + // found our type, extract the underlying type + if ident, ok := tspec.Type.(*ast.Ident); ok { + g.underlyingType = ident.Name + } + } + } + } + return true + }) +} - parseConstBlock := func(decl *ast.GenDecl) { - // extracts enum values from a const block - var iotaVal int - var lastExprWasIota bool - var lastExplicitVal int - var iotaBaseValue int // base value for iota (e.g., 1 in "iota + 1") - var iotaStarted bool // whether we've encountered an iota expression - var hasIotaOffset bool // whether we have an offset for iota - - for _, spec := range decl.Specs { - vspec, ok := spec.(*ast.ValueSpec) - if !ok || len(vspec.Names) == 0 { +// parseConstBlock extracts enum values from a const block +func (g *Generator) parseConstBlock(decl *ast.GenDecl) { + state := &constParseState{} + + for _, spec := range decl.Specs { + vspec, ok := spec.(*ast.ValueSpec) + if !ok || len(vspec.Names) == 0 { + continue + } + + // process all names in this spec + for i, name := range vspec.Names { + // skip underscore placeholders + if name.Name == "_" { continue } - // check if first name has our type prefix - if !strings.HasPrefix(vspec.Names[0].Name, g.Type) { + // only process names with our type prefix + if !strings.HasPrefix(name.Name, g.Type) { continue } - // process all names in this spec - for i, name := range vspec.Names { - if name.Name == "_" { // skip placeholder values - continue - } + // process value based on expression + enumValue := g.processConstValue(vspec, i, state) - // process value based on expression - switch { - case i < len(vspec.Values) && vspec.Values[i] != nil: - // there's a value expression, try to extract the actual value - switch expr := vspec.Values[i].(type) { - case *ast.Ident: - if expr.Name == "iota" { - // the expression is an iota identifier - g.values[name.Name] = iotaVal - lastExprWasIota = true - iotaStarted = true - hasIotaOffset = false // Reset offset for plain iota - } - case *ast.BasicLit: - // try to extract literal value - if val, err := convertLiteralToInt(expr); err == nil { - g.values[name.Name] = val - lastExplicitVal = val - lastExprWasIota = false - iotaStarted = false // Reset iota tracking for non-iota expressions - hasIotaOffset = false - } - case *ast.BinaryExpr: - // handle binary expressions like iota + 1 - val, usesIota, _, err := evaluateBinaryExpr(expr, iotaVal) - if err == nil { - g.values[name.Name] = val - lastExplicitVal = val - if usesIota { - lastExprWasIota = true - iotaStarted = true - if !hasIotaOffset { // Only set offset on first occurrence - // iotaOffset no longer needed - iotaBaseValue = val - iotaVal // Calculate base value - hasIotaOffset = true - } - } else { - lastExprWasIota = false - iotaStarted = false - hasIotaOffset = false - } - } - } - case lastExprWasIota: - // if previous expr was iota and this one has no value, assume iota continues - if hasIotaOffset && iotaStarted { - // If we have an offset (iota + N), apply the same formula - g.values[name.Name] = iotaBaseValue + iotaVal - } else { - g.values[name.Name] = iotaVal - } - default: - // if this constant omits its expression following a non-iota value, - // it repeats the previous expression (which means it gets the same value) - g.values[name.Name] = lastExplicitVal - } + // store the value with its position + g.values[name.Name] = &constValue{ + value: enumValue, + pos: name.Pos(), + } + } + + // always increment iota after each value spec + state.iotaVal++ + } +} + +// processConstValue extracts the value for a single constant +func (g *Generator) processConstValue(vspec *ast.ValueSpec, index int, state *constParseState) int { + // handle explicit expression if present + if index < len(vspec.Values) && vspec.Values[index] != nil { + return g.processExplicitValue(vspec.Values[index], state) + } + + // handle implicit expression based on previous state + return g.processImplicitValue(state) +} + +// processExplicitValue handles a constant with an explicit value expression +func (g *Generator) processExplicitValue(expr ast.Expr, state *constParseState) int { + switch e := expr.(type) { + case *ast.Ident: + if e.Name == "iota" { + state.lastExprType = exprTypeIota + state.lastValue = state.iotaVal + state.iotaOp = nil + return state.iotaVal + } + case *ast.BasicLit: + if val, err := ConvertLiteralToInt(e); err == nil { + state.lastExprType = exprTypePlain + state.lastValue = val + state.iotaOp = nil + return val + } + case *ast.BinaryExpr: + if val, op := g.processBinaryExpr(e, state); op != nil { + state.lastExprType = exprTypeIotaOp + state.lastValue = val + state.iotaOp = op + return val + } else if val != 0 || op == nil { + // plain binary expression without iota + state.lastExprType = exprTypePlain + state.lastValue = val + state.iotaOp = nil + return val + } + } + return 0 +} + +// processImplicitValue handles a constant without an explicit value +func (g *Generator) processImplicitValue(state *constParseState) int { + switch state.lastExprType { + case exprTypeIota: + // plain iota continues + return state.iotaVal + case exprTypeIotaOp: + // apply the operation with current iota + return g.applyIotaOperation(state.iotaOp, state.iotaVal) + default: + // repeat last plain value + return state.lastValue + } +} + +// processBinaryExpr processes a binary expression and returns the value and operation if it uses iota +func (g *Generator) processBinaryExpr(expr *ast.BinaryExpr, state *constParseState) (int, *iotaOperation) { + val, usesIota, err := EvaluateBinaryExpr(expr, state.iotaVal) + if err != nil { + return 0, nil + } + + if !usesIota { + return val, nil + } + + // extract operation details for iota expressions + op := &iotaOperation{op: expr.Op} - iotaVal++ + if ident, ok := expr.X.(*ast.Ident); ok && ident.Name == "iota" { + // iota op value + op.iotaOnLeft = true + if lit, ok := expr.Y.(*ast.BasicLit); ok { + if opVal, err := ConvertLiteralToInt(lit); err == nil { + op.operand = opVal + } + } + } else if ident, ok := expr.Y.(*ast.Ident); ok && ident.Name == "iota" { + // value op iota + op.iotaOnLeft = false + if lit, ok := expr.X.(*ast.BasicLit); ok { + if opVal, err := ConvertLiteralToInt(lit); err == nil { + op.operand = opVal } } } - ast.Inspect(file, func(n ast.Node) bool { - if decl, ok := n.(*ast.GenDecl); ok && decl.Tok == token.CONST { - parseConstBlock(decl) + return val, op +} + +// applyIotaOperation applies a stored operation to a new iota value +func (g *Generator) applyIotaOperation(op *iotaOperation, iotaVal int) int { + if op == nil { + return iotaVal + } + + switch op.op { + case token.ADD: + return iotaVal + op.operand + case token.SUB: + if op.iotaOnLeft { + return iotaVal - op.operand } - return true - }) + return op.operand - iotaVal + case token.MUL: + return iotaVal * op.operand + case token.QUO: + if op.operand != 0 { + if op.iotaOnLeft { + return iotaVal / op.operand + } + // note: integer division by iota could be 0 for large iota values + if iotaVal != 0 { + return op.operand / iotaVal + } + } + return 0 // division by zero + } + return iotaVal } -// convertLiteralToInt tries to convert a basic literal to an integer value -func convertLiteralToInt(lit *ast.BasicLit) (int, error) { +// ConvertLiteralToInt tries to convert a basic literal to an integer value +func ConvertLiteralToInt(lit *ast.BasicLit) (int, error) { switch lit.Kind { case token.INT: var val int @@ -203,19 +331,34 @@ func convertLiteralToInt(lit *ast.BasicLit) (int, error) { return val, nil } return 0, fmt.Errorf("cannot convert %s to int", lit.Value) + case token.CHAR: + // handle character literals like 'A' + // strconv.Unquote handles all escape sequences properly + unquoted, err := strconv.Unquote(lit.Value) + if err != nil { + return 0, fmt.Errorf("cannot parse character literal %s: %w", lit.Value, err) + } + // use utf8.DecodeRuneInString for safer UTF-8 handling + r, size := utf8.DecodeRuneInString(unquoted) + if r == utf8.RuneError { + return 0, fmt.Errorf("invalid UTF-8 in character literal %s", lit.Value) + } + if size != len(unquoted) { + return 0, fmt.Errorf("character literal %s contains multiple characters", lit.Value) + } + return int(r), nil default: return 0, fmt.Errorf("unsupported literal kind: %v", lit.Kind) } } -// evaluateBinaryExpr evaluates binary expressions like iota + 1 +// EvaluateBinaryExpr evaluates binary expressions like iota + 1 // Returns: // - value: the computed value of the expression // - usesIota: whether the expression uses iota -// - offset: the offset value if the expression is in the form of "iota + N" or "iota - N" // - error: any error encountered -func evaluateBinaryExpr(expr *ast.BinaryExpr, iotaVal int) (value int, usesIota bool, offset int, err error) { - // Handle left side of expression +func EvaluateBinaryExpr(expr *ast.BinaryExpr, iotaVal int) (value int, usesIota bool, err error) { + // handle left side of expression var leftVal int var leftIsIota bool @@ -225,19 +368,19 @@ func evaluateBinaryExpr(expr *ast.BinaryExpr, iotaVal int) (value int, usesIota leftVal = iotaVal leftIsIota = true } else { - return 0, false, 0, fmt.Errorf("unsupported identifier in binary expression: %s", left.Name) + return 0, false, fmt.Errorf("unsupported identifier in binary expression: %s", left.Name) } case *ast.BasicLit: var err error - leftVal, err = convertLiteralToInt(left) + leftVal, err = ConvertLiteralToInt(left) if err != nil { - return 0, false, 0, err + return 0, false, err } default: - return 0, false, 0, fmt.Errorf("unsupported expression type on left side: %T", left) + return 0, false, fmt.Errorf("unsupported expression type on left side: %T", left) } - // Handle right side of expression + // handle right side of expression var rightVal int var rightIsIota bool @@ -247,32 +390,22 @@ func evaluateBinaryExpr(expr *ast.BinaryExpr, iotaVal int) (value int, usesIota rightVal = iotaVal rightIsIota = true } else { - return 0, false, 0, fmt.Errorf("unsupported identifier in binary expression: %s", right.Name) + return 0, false, fmt.Errorf("unsupported identifier in binary expression: %s", right.Name) } case *ast.BasicLit: var err error - rightVal, err = convertLiteralToInt(right) + rightVal, err = ConvertLiteralToInt(right) if err != nil { - return 0, false, 0, err + return 0, false, err } default: - return 0, false, 0, fmt.Errorf("unsupported expression type on right side: %T", right) + return 0, false, fmt.Errorf("unsupported expression type on right side: %T", right) } - // Check if expression uses iota + // check if expression uses iota usesIota = leftIsIota || rightIsIota - // Calculate offset for expressions like "iota + N" or "iota - N" - switch { - case expr.Op == token.ADD && leftIsIota && !rightIsIota: - offset = rightVal - case expr.Op == token.ADD && rightIsIota && !leftIsIota: - offset = leftVal - case expr.Op == token.SUB && leftIsIota && !rightIsIota: - offset = -rightVal - } - - // Evaluate the expression based on the operator + // evaluate the expression based on the operator switch expr.Op { case token.ADD: value = leftVal + rightVal @@ -282,14 +415,14 @@ func evaluateBinaryExpr(expr *ast.BinaryExpr, iotaVal int) (value int, usesIota value = leftVal * rightVal case token.QUO: if rightVal == 0 { - return 0, false, 0, fmt.Errorf("division by zero") + return 0, false, fmt.Errorf("division by zero") } value = leftVal / rightVal default: - return 0, false, 0, fmt.Errorf("unsupported binary operator: %v", expr.Op) + return 0, false, fmt.Errorf("unsupported binary operator: %v", expr.Op) } - return value, usesIota, offset, nil + return value, usesIota, nil } // Generate creates the enum code file. it takes the const values found in Parse and creates @@ -302,17 +435,15 @@ func evaluateBinaryExpr(expr *ast.BinaryExpr, iotaVal int) (value int, usesIota // - exported const values (e.g., StatusActive) // - helper functions to get all values and names func (g *Generator) Generate() error { - values := make([]Value, 0, len(g.values)) - names := make([]string, 0, len(g.values)) // to avoid an undefined behavior for a Getter, we need to check if the values are unique if g.generateGetter { valuesCounter := make(map[int][]string) // check if multiple names exist for the same value - for name, val := range g.values { - if _, ok := valuesCounter[val]; !ok { - valuesCounter[val] = []string{} + for name, cv := range g.values { + if _, ok := valuesCounter[cv.value]; !ok { + valuesCounter[cv.value] = []string{} } - valuesCounter[val] = append(valuesCounter[val], name) + valuesCounter[cv.value] = append(valuesCounter[cv.value], name) } var errs []error for val, names := range valuesCounter { @@ -326,15 +457,26 @@ func (g *Generator) Generate() error { return errors.Join(errs...) } } - // collect names for stable ordering - for name := range g.values { - names = append(names, name) + + // collect entries for sorting by position + type entry struct { + name string + cv *constValue + } + entries := make([]entry, 0, len(g.values)) + for name, cv := range g.values { + entries = append(entries, entry{name: name, cv: cv}) } - sort.Strings(names) + + // sort by source position to preserve declaration order + sort.Slice(entries, func(i, j int) bool { + return entries[i].cv.pos < entries[j].cv.pos + }) // create values with proper name transformations for each case - for _, name := range names { - privateName := name + values := make([]Value, 0, len(entries)) + for _, e := range entries { + privateName := e.name // strip type prefix to get just the value name part (e.g., "Active" from "statusActive") nameWithoutPrefix := strings.TrimPrefix(privateName, g.Type) // create exported name by adding title-cased type (e.g., "StatusActive") @@ -343,7 +485,7 @@ func (g *Generator) Generate() error { PrivateName: privateName, PublicName: publicName, Name: titleCaser.String(nameWithoutPrefix), - Index: g.values[name], + Index: e.cv.value, }) } @@ -366,12 +508,14 @@ func (g *Generator) Generate() error { Package string LowerCase bool GenerateGetter bool + UnderlyingType string }{ Type: g.Type, Values: values, Package: pkgName, LowerCase: g.lowerCase, GenerateGetter: g.generateGetter, + UnderlyingType: g.underlyingType, } // execute template diff --git a/internal/generator/generator_test.go b/internal/generator/generator_test.go index 7ebc0fa..389b6f0 100644 --- a/internal/generator/generator_test.go +++ b/internal/generator/generator_test.go @@ -1,10 +1,13 @@ package generator import ( + "bytes" + "go/ast" "go/parser" "go/token" "os" "path/filepath" + "strings" "testing" "text/template" @@ -68,8 +71,8 @@ func TestGenerator(t *testing.T) { "Scan(value interface{}) error", "ParseStatus(v string) (Status, error)", "MustStatus(v string) Status", - "StatusValues() []Status", - "StatusNames() []string", + "var StatusValues = []Status", + "var StatusNames = []string", } for _, method := range methods { assert.Contains(t, string(content), method, "method %s should be present", method) @@ -151,7 +154,7 @@ func TestGenerator(t *testing.T) { // verify nil handling assert.Contains(t, string(content), "if value == nil {") - assert.Contains(t, string(content), "StatusValues()[0]") + assert.Contains(t, string(content), "if v.Index() == 0") // verify []byte support assert.Contains(t, string(content), "if b, ok := value.([]byte)") @@ -238,7 +241,7 @@ func TestGenerator(t *testing.T) { require.NoError(t, err) // check content - assert.Contains(t, string(content), "func GetJobStatusByID(v int) (JobStatus, error)") + assert.Contains(t, string(content), "func GetJobStatusByID(v uint8) (JobStatus, error)") assert.Contains(t, string(content), "case 0:\n\t\treturn JobStatusUnknown, nil") assert.Contains(t, string(content), "case 1:\n\t\treturn JobStatusActive, nil") assert.Contains(t, string(content), "case 2:\n\t\treturn JobStatusInactive, nil") @@ -266,7 +269,7 @@ func TestGenerator(t *testing.T) { require.NoError(t, err) // check content - assert.Contains(t, string(content), "func GetExplicitValuesByID(v int) (ExplicitValues, error)") + assert.Contains(t, string(content), "func GetExplicitValuesByID(v uint8) (ExplicitValues, error)") assert.Contains(t, string(content), "case 10:\n\t\treturn ExplicitValuesFirst, nil") assert.Contains(t, string(content), "case 20:\n\t\treturn ExplicitValuesSecond, nil") assert.Contains(t, string(content), "case 30:\n\t\treturn ExplicitValuesThird, nil") @@ -332,10 +335,10 @@ func TestGeneratorValues(t *testing.T) { err = gen.Parse("testdata") require.NoError(t, err) - assert.Equal(t, 0, gen.values["statusUnknown"], "unknown should be 0") - assert.Equal(t, 1, gen.values["statusActive"], "active should be 1") - assert.Equal(t, 2, gen.values["statusInactive"], "inactive should be 2") - assert.Equal(t, 3, gen.values["statusBlocked"], "blocked should be 3") + assert.Equal(t, 0, gen.values["statusUnknown"].value, "unknown should be 0") + assert.Equal(t, 1, gen.values["statusActive"].value, "active should be 1") + assert.Equal(t, 2, gen.values["statusInactive"].value, "inactive should be 2") + assert.Equal(t, 3, gen.values["statusBlocked"].value, "blocked should be 3") } func TestRepeatValues(t *testing.T) { @@ -347,10 +350,102 @@ func TestRepeatValues(t *testing.T) { err = gen.Parse("testdata") require.NoError(t, err) - assert.Equal(t, 10, gen.values["repeatValuesFirst"], "First should be 10") - assert.Equal(t, 10, gen.values["repeatValuesSecond"], "Second should repeat the value 10") // currently fails - assert.Equal(t, 20, gen.values["repeatValuesThird"], "Third should be 20") - assert.Equal(t, 20, gen.values["repeatValuesFourth"], "Fourth should repeat the value 20") // currently fails + assert.Equal(t, 10, gen.values["repeatValuesFirst"].value, "First should be 10") + assert.Equal(t, 10, gen.values["repeatValuesSecond"].value, "Second should repeat the value 10") + assert.Equal(t, 20, gen.values["repeatValuesThird"].value, "Third should be 20") + assert.Equal(t, 20, gen.values["repeatValuesFourth"].value, "Fourth should repeat the value 20") +} + +func TestSQLNullHandling(t *testing.T) { + t.Run("with zero value", func(t *testing.T) { + tmpDir := t.TempDir() + gen, err := New("status", tmpDir) + require.NoError(t, err) + + err = gen.Parse("testdata") + require.NoError(t, err) + + err = gen.Generate() + require.NoError(t, err) + + content, err := os.ReadFile(filepath.Join(tmpDir, "status_enum.go")) + require.NoError(t, err) + + // should scan nil to zero value when it exists + assert.Contains(t, string(content), "if v.Index() == 0") + assert.Contains(t, string(content), "*e = v") + assert.Contains(t, string(content), "return nil") + }) + + t.Run("without zero value", func(t *testing.T) { + tmpDir := t.TempDir() + gen, err := New("noZero", tmpDir) + require.NoError(t, err) + + err = gen.Parse("testdata") + require.NoError(t, err) + + err = gen.Generate() + require.NoError(t, err) + + content, err := os.ReadFile(filepath.Join(tmpDir, "no_zero_enum.go")) + require.NoError(t, err) + + // should return error when no zero value exists + assert.Contains(t, string(content), "cannot scan nil into NoZero: no zero value defined") + }) +} + +func TestDeclarationOrderPreservation(t *testing.T) { + tmpDir := t.TempDir() + + gen, err := New("orderTest", tmpDir) + require.NoError(t, err) + + err = gen.Parse("testdata") + require.NoError(t, err) + + // generate the enum + err = gen.Generate() + require.NoError(t, err) + + // read the generated file + content, err := os.ReadFile(filepath.Join(tmpDir, "order_test_enum.go")) + require.NoError(t, err) + + // check that values appear in declaration order in Values() function + // the order should be Zero, Alpha, Charlie, Bravo (not alphabetical) + contentStr := string(content) + + // find the Values var and check order + valuesIdx := strings.Index(contentStr, "var OrderTestValues = []OrderTest") + require.GreaterOrEqual(t, valuesIdx, 0, "Should find OrderTestValues var") + valuesSection := contentStr[valuesIdx : valuesIdx+300] + + // check order - Zero should come before Alpha, Alpha before Charlie, Charlie before Bravo + zeroIdx := strings.Index(valuesSection, "OrderTestZero") + alphaIdx := strings.Index(valuesSection, "OrderTestAlpha") + charlieIdx := strings.Index(valuesSection, "OrderTestCharlie") + bravoIdx := strings.Index(valuesSection, "OrderTestBravo") + + assert.Less(t, zeroIdx, alphaIdx, "Zero should come before Alpha") + assert.Less(t, alphaIdx, charlieIdx, "Alpha should come before Charlie") + assert.Less(t, charlieIdx, bravoIdx, "Charlie should come before Bravo (not alphabetical)") + + // find the Names var and check order + namesIdx := strings.Index(contentStr, "var OrderTestNames = []string") + require.GreaterOrEqual(t, namesIdx, 0, "Should find OrderTestNames var") + namesSection := contentStr[namesIdx : namesIdx+200] + + // check order in names + zeroNameIdx := strings.Index(namesSection, `"Zero"`) + alphaNameIdx := strings.Index(namesSection, `"Alpha"`) + charlieNameIdx := strings.Index(namesSection, `"Charlie"`) + bravoNameIdx := strings.Index(namesSection, `"Bravo"`) + + assert.Less(t, zeroNameIdx, alphaNameIdx, "Zero name should come before Alpha") + assert.Less(t, alphaNameIdx, charlieNameIdx, "Alpha name should come before Charlie") + assert.Less(t, charlieNameIdx, bravoNameIdx, "Charlie name should come before Bravo (not alphabetical)") } func TestBinaryExprValues(t *testing.T) { @@ -362,30 +457,30 @@ func TestBinaryExprValues(t *testing.T) { err = gen.Parse("testdata") require.NoError(t, err) - // Check that all values are found + // check that all values are found assert.Contains(t, gen.values, "binaryExprFirst", "First value should be found") assert.Contains(t, gen.values, "binaryExprSecond", "Second value should be found") assert.Contains(t, gen.values, "binaryExprThird", "Third value should be found") - // Check that values are correct (iota + 1) - assert.Equal(t, 1, gen.values["binaryExprFirst"], "First should be 1") - assert.Equal(t, 2, gen.values["binaryExprSecond"], "Second should be 2") - assert.Equal(t, 3, gen.values["binaryExprThird"], "Third should be 3") + // check that values are correct (iota + 1) + assert.Equal(t, 1, gen.values["binaryExprFirst"].value, "First should be 1") + assert.Equal(t, 2, gen.values["binaryExprSecond"].value, "Second should be 2") + assert.Equal(t, 3, gen.values["binaryExprThird"].value, "Third should be 3") - // Generate the enum and verify it contains all constants + // generate the enum and verify it contains all constants err = gen.Generate() require.NoError(t, err) - // Verify file was created + // verify file was created content, err := os.ReadFile(filepath.Join(tmpDir, "binary_expr_enum.go")) require.NoError(t, err) - // Check that all constants are present in the generated file + // check that all constants are present in the generated file assert.Contains(t, string(content), "BinaryExprFirst") assert.Contains(t, string(content), "BinaryExprSecond") assert.Contains(t, string(content), "BinaryExprThird") - // Check the values are correct + // check the values are correct assert.Contains(t, string(content), "value: 1") assert.Contains(t, string(content), "value: 2") assert.Contains(t, string(content), "value: 3") @@ -438,9 +533,13 @@ func TestGeneratorLowerCase(t *testing.T) { assert.Contains(t, string(content), `name: "inactive"`) assert.Contains(t, string(content), `name: "unknown"`) - // check unmarshal code compares with lowercase - assert.Contains(t, string(content), `case "active":`) - assert.NotContains(t, string(content), "strings.ToLower") + // check parse map has lowercase keys + assert.Contains(t, string(content), `"active": StatusActive`) + // for lowercase mode, we don't use strings.ToLower in Parse function + parseIdx := bytes.Index(content, []byte("func ParseStatus")) + parseEnd := bytes.Index(content[parseIdx:], []byte("}")) + parseFunc := string(content[parseIdx : parseIdx+parseEnd]) + assert.NotContains(t, parseFunc, "strings.ToLower") }) t.Run("regular case values", func(t *testing.T) { @@ -545,7 +644,7 @@ const ( // check that the unused constants prevention code exists assert.Contains(t, string(content), "// These variables are used to prevent the compiler from reporting unused errors") assert.Contains(t, string(content), "var _ = func() bool {") - assert.Contains(t, string(content), "var _ linterTest = 0") + assert.Contains(t, string(content), "var _ linterTest = linterTest(0)") assert.Contains(t, string(content), "var _ linterTest = linterTestUnknown") assert.Contains(t, string(content), "var _ linterTest = linterTestValue1") assert.Contains(t, string(content), "var _ linterTest = linterTestValue2") @@ -681,3 +780,908 @@ func TestGetFileNameForType(t *testing.T) { }) } } + +func TestUnderlyingTypePreservation(t *testing.T) { + t.Run("uint8 type", func(t *testing.T) { + tmpDir := t.TempDir() + gen, err := New("status", tmpDir) + require.NoError(t, err) + + err = gen.Parse("testdata") + require.NoError(t, err) + + // check that underlying type was captured + assert.Equal(t, "uint8", gen.underlyingType) + + err = gen.Generate() + require.NoError(t, err) + + content, err := os.ReadFile(filepath.Join(tmpDir, "status_enum.go")) + require.NoError(t, err) + + // verify that the generated code uses uint8 + assert.Contains(t, string(content), "value uint8") + assert.Contains(t, string(content), "func (e Status) Index() uint8") + assert.NotContains(t, string(content), "value int\n") // should not have plain int + }) + + t.Run("uint16 type", func(t *testing.T) { + tmpDir := t.TempDir() + gen, err := New("uint16Type", tmpDir) + require.NoError(t, err) + + err = gen.Parse("testdata") + require.NoError(t, err) + + assert.Equal(t, "uint16", gen.underlyingType) + + err = gen.Generate() + require.NoError(t, err) + + content, err := os.ReadFile(filepath.Join(tmpDir, "uint16_type_enum.go")) + require.NoError(t, err) + + assert.Contains(t, string(content), "value uint16") + assert.Contains(t, string(content), "func (e Uint16Type) Index() uint16") + }) + + t.Run("int32 type", func(t *testing.T) { + tmpDir := t.TempDir() + gen, err := New("int32Type", tmpDir) + require.NoError(t, err) + + err = gen.Parse("testdata") + require.NoError(t, err) + + assert.Equal(t, "int32", gen.underlyingType) + + err = gen.Generate() + require.NoError(t, err) + + content, err := os.ReadFile(filepath.Join(tmpDir, "int32_type_enum.go")) + require.NoError(t, err) + + assert.Contains(t, string(content), "value int32") + assert.Contains(t, string(content), "func (e Int32Type) Index() int32") + // check that values are correct (100, 101) + assert.Contains(t, string(content), "value: 100") + assert.Contains(t, string(content), "value: 101") + }) + + t.Run("byte type alias", func(t *testing.T) { + tmpDir := t.TempDir() + gen, err := New("byteType", tmpDir) + require.NoError(t, err) + + err = gen.Parse("testdata") + require.NoError(t, err) + + // byte is an alias for uint8, but ast gives us "byte" + assert.Equal(t, "byte", gen.underlyingType) + + err = gen.Generate() + require.NoError(t, err) + + content, err := os.ReadFile(filepath.Join(tmpDir, "byte_type_enum.go")) + require.NoError(t, err) + + assert.Contains(t, string(content), "value byte") + assert.Contains(t, string(content), "func (e ByteType) Index() byte") + }) + + t.Run("rune type alias", func(t *testing.T) { + tmpDir := t.TempDir() + gen, err := New("runeType", tmpDir) + require.NoError(t, err) + + err = gen.Parse("testdata") + require.NoError(t, err) + + // rune is an alias for int32, but ast gives us "rune" + assert.Equal(t, "rune", gen.underlyingType) + + err = gen.Generate() + require.NoError(t, err) + + content, err := os.ReadFile(filepath.Join(tmpDir, "rune_type_enum.go")) + require.NoError(t, err) + + assert.Contains(t, string(content), "value rune") + assert.Contains(t, string(content), "func (e RuneType) Index() rune") + // check that values are correct ('A' = 65, 'B' = 66) + assert.Contains(t, string(content), "value: 65") + assert.Contains(t, string(content), "value: 66") + }) + + t.Run("default int type", func(t *testing.T) { + tmpDir := t.TempDir() + + // create a test file without explicit type + testFile := `package test +const ( + someUnknown = iota + someActive +) +` + err := os.WriteFile(filepath.Join(tmpDir, "test.go"), []byte(testFile), 0o644) + require.NoError(t, err) + + gen, err := New("some", tmpDir) + require.NoError(t, err) + + err = gen.Parse(tmpDir) + require.NoError(t, err) + + // check that underlying type is empty (will default to int) + assert.Empty(t, gen.underlyingType) + + err = gen.Generate() + require.NoError(t, err) + + content, err := os.ReadFile(filepath.Join(tmpDir, "some_enum.go")) + require.NoError(t, err) + + // verify that the generated code uses int as default + assert.Contains(t, string(content), "value int") + assert.Contains(t, string(content), "func (e Some) Index() int") + }) +} + +func TestCaseInsensitiveParsing(t *testing.T) { + tmpDir := t.TempDir() + gen, err := New("status", tmpDir) + require.NoError(t, err) + + err = gen.Parse("testdata") + require.NoError(t, err) + + err = gen.Generate() + require.NoError(t, err) + + content, err := os.ReadFile(filepath.Join(tmpDir, "status_enum.go")) + require.NoError(t, err) + + // verify that parsing uses strings.ToLower for case-insensitive matching + assert.Contains(t, string(content), "strings.ToLower(v)") + + // verify the parse map has lowercase keys + assert.Contains(t, string(content), `"unknown": StatusUnknown`) + assert.Contains(t, string(content), `"active": StatusActive`) + assert.Contains(t, string(content), `"inactive": StatusInactive`) + assert.Contains(t, string(content), `"blocked": StatusBlocked`) +} + +func TestGeneratedCodeUsesVariables(t *testing.T) { + tmpDir := t.TempDir() + gen, err := New("status", tmpDir) + require.NoError(t, err) + + err = gen.Parse("testdata") + require.NoError(t, err) + + err = gen.Generate() + require.NoError(t, err) + + content, err := os.ReadFile(filepath.Join(tmpDir, "status_enum.go")) + require.NoError(t, err) + + // verify that Values and Names are variables, not functions + assert.Contains(t, string(content), "var StatusValues = []Status") + assert.Contains(t, string(content), "var StatusNames = []string") + + // should NOT have function signatures + assert.NotContains(t, string(content), "func StatusValues()") + assert.NotContains(t, string(content), "func StatusNames()") + + // verify parse map is a variable + assert.Contains(t, string(content), "var _statusParseMap = map[string]Status") +} + +func TestGetterWithDifferentTypes(t *testing.T) { + t.Run("getter with uint16", func(t *testing.T) { + tmpDir := t.TempDir() + gen, err := New("uint16Type", tmpDir) + require.NoError(t, err) + gen.SetGenerateGetter(true) + + err = gen.Parse("testdata") + require.NoError(t, err) + + err = gen.Generate() + require.NoError(t, err) + + content, err := os.ReadFile(filepath.Join(tmpDir, "uint16_type_enum.go")) + require.NoError(t, err) + + // verify getter uses uint16 + assert.Contains(t, string(content), "func GetUint16TypeByID(v uint16) (Uint16Type, error)") + }) + + t.Run("getter with int32", func(t *testing.T) { + tmpDir := t.TempDir() + gen, err := New("int32Type", tmpDir) + require.NoError(t, err) + gen.SetGenerateGetter(true) + + err = gen.Parse("testdata") + require.NoError(t, err) + + err = gen.Generate() + require.NoError(t, err) + + content, err := os.ReadFile(filepath.Join(tmpDir, "int32_type_enum.go")) + require.NoError(t, err) + + // verify getter uses int32 + assert.Contains(t, string(content), "func GetInt32TypeByID(v int32) (Int32Type, error)") + // verify it has correct values + assert.Contains(t, string(content), "case 100:") + assert.Contains(t, string(content), "case 101:") + }) +} + +func TestBinaryExpressionEdgeCases(t *testing.T) { + t.Run("multiplication with iota", func(t *testing.T) { + tmpDir := t.TempDir() + gen, err := New("mulDivType", tmpDir) + require.NoError(t, err) + + err = gen.Parse("testdata") + require.NoError(t, err) + + // check values + assert.Equal(t, 0, gen.values["mulDivTypeA"].value) + assert.Equal(t, 2, gen.values["mulDivTypeB"].value) + assert.Equal(t, 4, gen.values["mulDivTypeC"].value) + }) + + t.Run("right-side iota addition", func(t *testing.T) { + tmpDir := t.TempDir() + gen, err := New("rightIotaType", tmpDir) + require.NoError(t, err) + + err = gen.Parse("testdata") + require.NoError(t, err) + + // check values + assert.Equal(t, 10, gen.values["rightIotaTypeX"].value) + assert.Equal(t, 11, gen.values["rightIotaTypeY"].value) + }) + + t.Run("subtraction with iota", func(t *testing.T) { + tmpDir := t.TempDir() + gen, err := New("subType", tmpDir) + require.NoError(t, err) + + err = gen.Parse("testdata") + require.NoError(t, err) + + // check values + assert.Equal(t, 100, gen.values["subTypeA"].value) + assert.Equal(t, 99, gen.values["subTypeB"].value) + assert.Equal(t, 98, gen.values["subTypeC"].value) + }) +} + +func TestConvertLiteralToInt(t *testing.T) { + tests := []struct { + name string + literal *ast.BasicLit + expected int + expectErr bool + }{ + { + name: "integer literal", + literal: &ast.BasicLit{Kind: token.INT, Value: "42"}, + expected: 42, + }, + { + name: "character literal single quote", + literal: &ast.BasicLit{Kind: token.CHAR, Value: "'A'"}, + expected: 65, + }, + { + name: "character literal escape", + literal: &ast.BasicLit{Kind: token.CHAR, Value: "'\\n'"}, + expected: 10, + }, + { + name: "invalid integer format", + literal: &ast.BasicLit{Kind: token.INT, Value: "not_a_number"}, + expectErr: true, + }, + { + name: "multi-character literal", + literal: &ast.BasicLit{Kind: token.CHAR, Value: "'AB'"}, + expectErr: true, + }, + { + name: "invalid character literal", + literal: &ast.BasicLit{Kind: token.CHAR, Value: "invalid"}, + expectErr: true, + }, + { + name: "unsupported literal kind", + literal: &ast.BasicLit{Kind: token.FLOAT, Value: "3.14"}, + expectErr: true, + }, + { + name: "character literal tab", + literal: &ast.BasicLit{Kind: token.CHAR, Value: "'\\t'"}, + expected: 9, + }, + { + name: "character literal null", + literal: &ast.BasicLit{Kind: token.CHAR, Value: "'\\x00'"}, + expected: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := ConvertLiteralToInt(tt.literal) + if tt.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func TestUnderscorePlaceholderConstants(t *testing.T) { + // test that underscore placeholders are skipped + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.go") + src := `package test + type status int + const ( + statusFirst = iota + _ // skip this value + statusSecond + _ // skip this too + statusThird + )` + require.NoError(t, os.WriteFile(testFile, []byte(src), 0o644)) + + gen, err := New("status", "") + require.NoError(t, err) + err = gen.Parse(tmpDir) + require.NoError(t, err) + + // check that underscore placeholders were skipped but iota still incremented + assert.Equal(t, 0, gen.values["statusFirst"].value) + assert.Equal(t, 2, gen.values["statusSecond"].value) // iota=2 (after _ at iota=1) + assert.Equal(t, 4, gen.values["statusThird"].value) // iota=4 (after _ at iota=3) + _, exists := gen.values["_"] + assert.False(t, exists, "underscore should not be in values") +} + +func TestDivisionOperationsWithIota(t *testing.T) { + // test division operations in applyIotaOperation + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.go") + src := `package test + type divType int + const ( + divTypeA = iota / 2 + divTypeB + divTypeC + divTypeD + )` + require.NoError(t, os.WriteFile(testFile, []byte(src), 0o644)) + + gen, err := New("divType", "") + require.NoError(t, err) + err = gen.Parse(tmpDir) + require.NoError(t, err) + + // iota/2: 0/2=0, 1/2=0, 2/2=1, 3/2=1 + assert.Equal(t, 0, gen.values["divTypeA"].value) + assert.Equal(t, 0, gen.values["divTypeB"].value) + assert.Equal(t, 1, gen.values["divTypeC"].value) + assert.Equal(t, 1, gen.values["divTypeD"].value) +} + +func TestSubtractionWithIota(t *testing.T) { + // test subtraction operations - both iota - N and N - iota + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.go") + src := `package test + type subType int + const ( + subTypeA = 10 - iota // 10 - 0 = 10 + subTypeB // 10 - 1 = 9 + subTypeC // 10 - 2 = 8 + subTypeD = iota - 1 // 3 - 1 = 2 + subTypeE // 4 - 1 = 3 + )` + require.NoError(t, os.WriteFile(testFile, []byte(src), 0o644)) + + gen, err := New("subType", "") + require.NoError(t, err) + err = gen.Parse(tmpDir) + require.NoError(t, err) + + assert.Equal(t, 10, gen.values["subTypeA"].value) + assert.Equal(t, 9, gen.values["subTypeB"].value) + assert.Equal(t, 8, gen.values["subTypeC"].value) + assert.Equal(t, 2, gen.values["subTypeD"].value) + assert.Equal(t, 3, gen.values["subTypeE"].value) +} + +func TestEmptyConstBlock(t *testing.T) { + // test handling of empty const blocks + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.go") + src := `package test + type emptyType int + const ( + // this const block has no values + ) + const ( + emptyTypeFirst = iota + )` + require.NoError(t, os.WriteFile(testFile, []byte(src), 0o644)) + + gen, err := New("emptyType", "") + require.NoError(t, err) + err = gen.Parse(tmpDir) + require.NoError(t, err) + + assert.Equal(t, 0, gen.values["emptyTypeFirst"].value) +} + +func TestZeroBinaryExpression(t *testing.T) { + // test a binary expression that evaluates to 0 without iota + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.go") + src := `package test + type zeroType int + const ( + zeroTypeA = 5 - 5 // plain binary expr that equals 0 + zeroTypeB = iota // should be 1 + )` + require.NoError(t, os.WriteFile(testFile, []byte(src), 0o644)) + + gen, err := New("zeroType", "") + require.NoError(t, err) + err = gen.Parse(tmpDir) + require.NoError(t, err) + + assert.Equal(t, 0, gen.values["zeroTypeA"].value) + assert.Equal(t, 1, gen.values["zeroTypeB"].value) +} + +func TestEvaluateBinaryExpr(t *testing.T) { + tests := []struct { + name string + expr *ast.BinaryExpr + iotaVal int + expectedVal int + expectedIota bool + expectErr bool + }{ + { + name: "iota + 1", + expr: &ast.BinaryExpr{ + X: &ast.Ident{Name: "iota"}, + Op: token.ADD, + Y: &ast.BasicLit{Kind: token.INT, Value: "1"}, + }, + iotaVal: 0, + expectedVal: 1, + expectedIota: true, + }, + { + name: "iota * 2", + expr: &ast.BinaryExpr{ + X: &ast.Ident{Name: "iota"}, + Op: token.MUL, + Y: &ast.BasicLit{Kind: token.INT, Value: "2"}, + }, + iotaVal: 3, + expectedVal: 6, + expectedIota: true, + }, + { + name: "100 - iota", + expr: &ast.BinaryExpr{ + X: &ast.BasicLit{Kind: token.INT, Value: "100"}, + Op: token.SUB, + Y: &ast.Ident{Name: "iota"}, + }, + iotaVal: 2, + expectedVal: 98, + expectedIota: true, + }, + { + name: "iota - 5", + expr: &ast.BinaryExpr{ + X: &ast.Ident{Name: "iota"}, + Op: token.SUB, + Y: &ast.BasicLit{Kind: token.INT, Value: "5"}, + }, + iotaVal: 10, + expectedVal: 5, + expectedIota: true, + }, + { + name: "10 + iota", + expr: &ast.BinaryExpr{ + X: &ast.BasicLit{Kind: token.INT, Value: "10"}, + Op: token.ADD, + Y: &ast.Ident{Name: "iota"}, + }, + iotaVal: 2, + expectedVal: 12, + expectedIota: true, + }, + { + name: "iota / 2", + expr: &ast.BinaryExpr{ + X: &ast.Ident{Name: "iota"}, + Op: token.QUO, + Y: &ast.BasicLit{Kind: token.INT, Value: "2"}, + }, + iotaVal: 4, + expectedVal: 2, + expectedIota: true, + }, + { + name: "division by zero", + expr: &ast.BinaryExpr{ + X: &ast.Ident{Name: "iota"}, + Op: token.QUO, + Y: &ast.BasicLit{Kind: token.INT, Value: "0"}, + }, + iotaVal: 1, + expectErr: true, + }, + { + name: "unsupported operator", + expr: &ast.BinaryExpr{ + X: &ast.Ident{Name: "iota"}, + Op: token.REM, + Y: &ast.BasicLit{Kind: token.INT, Value: "2"}, + }, + iotaVal: 1, + expectErr: true, + }, + { + name: "unsupported left identifier", + expr: &ast.BinaryExpr{ + X: &ast.Ident{Name: "unknown"}, + Op: token.ADD, + Y: &ast.BasicLit{Kind: token.INT, Value: "1"}, + }, + iotaVal: 0, + expectErr: true, + }, + { + name: "unsupported right identifier", + expr: &ast.BinaryExpr{ + X: &ast.BasicLit{Kind: token.INT, Value: "1"}, + Op: token.ADD, + Y: &ast.Ident{Name: "unknown"}, + }, + iotaVal: 0, + expectErr: true, + }, + { + name: "invalid left literal", + expr: &ast.BinaryExpr{ + X: &ast.BasicLit{Kind: token.INT, Value: "invalid"}, + Op: token.ADD, + Y: &ast.Ident{Name: "iota"}, + }, + iotaVal: 0, + expectErr: true, + }, + { + name: "invalid right literal", + expr: &ast.BinaryExpr{ + X: &ast.Ident{Name: "iota"}, + Op: token.ADD, + Y: &ast.BasicLit{Kind: token.INT, Value: "invalid"}, + }, + iotaVal: 0, + expectErr: true, + }, + { + name: "unsupported left type", + expr: &ast.BinaryExpr{ + X: &ast.CallExpr{}, + Op: token.ADD, + Y: &ast.BasicLit{Kind: token.INT, Value: "1"}, + }, + iotaVal: 0, + expectErr: true, + }, + { + name: "unsupported right type", + expr: &ast.BinaryExpr{ + X: &ast.BasicLit{Kind: token.INT, Value: "1"}, + Op: token.ADD, + Y: &ast.CallExpr{}, + }, + iotaVal: 0, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + val, usesIota, err := EvaluateBinaryExpr(tt.expr, tt.iotaVal) + if tt.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expectedVal, val) + assert.Equal(t, tt.expectedIota, usesIota) + } + }) + } +} + +func TestApplyIotaOperationNil(t *testing.T) { + gen, err := New("test", "") + require.NoError(t, err) + + // test nil operation returns iotaVal unchanged + result := gen.applyIotaOperation(nil, 42) + assert.Equal(t, 42, result) +} + +func TestDivisionByZeroInQUO(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.go") + src := `package test +type divZero int +const ( + divZeroA = 10 / iota // division by zero when iota=0 +) +` + require.NoError(t, os.WriteFile(testFile, []byte(src), 0o644)) + + gen, err := New("divZero", "") + require.NoError(t, err) + err = gen.Parse(tmpDir) + require.NoError(t, err) + + // should handle division by zero gracefully + assert.Equal(t, 0, gen.values["divZeroA"].value) +} + +func TestInvalidUTF8CharacterLiteral(t *testing.T) { + // test ConvertLiteralToInt with a hex value that is valid + lit := &ast.BasicLit{ + Kind: token.CHAR, + Value: "'\\x80'", // this is handled correctly by strconv.Unquote + } + + val, err := ConvertLiteralToInt(lit) + require.Error(t, err) // should error because \x80 is not valid UTF-8 for a char + assert.Contains(t, err.Error(), "invalid UTF-8") + assert.Equal(t, 0, val) +} + +func TestMultipleCharactersInLiteral(t *testing.T) { + // test ConvertLiteralToInt with multiple characters + lit := &ast.BasicLit{ + Kind: token.CHAR, + Value: "'ab'", // invalid: multiple characters + } + + val, err := ConvertLiteralToInt(lit) + require.Error(t, err) + assert.Contains(t, err.Error(), "cannot parse character literal") + assert.Equal(t, 0, val) +} + +func TestGenerateWriteFileError(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.go") + src := `package test +type writeErr int +const ( + writeErrA = iota + writeErrB +) +` + require.NoError(t, os.WriteFile(testFile, []byte(src), 0o644)) + + gen, err := New("writeErr", "/nonexistent/path/that/cannot/be/created/because/parent/does/not/exist") + require.NoError(t, err) + err = gen.Parse(tmpDir) + require.NoError(t, err) + + err = gen.Generate() + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to create output directory") +} + +func TestEmptyValueSpec(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.go") + // create a const block with type declaration but no names + src := `package test +type emptySpec int +const ( + emptySpecA = iota +) +` + require.NoError(t, os.WriteFile(testFile, []byte(src), 0o644)) + + gen, err := New("emptySpec", "") + require.NoError(t, err) + err = gen.Parse(tmpDir) + require.NoError(t, err) + + assert.Equal(t, 0, gen.values["emptySpecA"].value) +} + +func TestProcessExplicitValueDefaultReturn(t *testing.T) { + gen, err := New("test", "") + require.NoError(t, err) + + state := &constParseState{} + + // test with an unsupported expression type to trigger default return + expr := &ast.ParenExpr{} // unsupported type + result := gen.processExplicitValue(expr, state) + assert.Equal(t, 0, result) +} + +func TestApplyIotaOperationDefaultCase(t *testing.T) { + gen, err := New("test", "") + require.NoError(t, err) + + // test with unsupported operation to trigger default case + op := &iotaOperation{ + op: token.AND, // unsupported operation + operand: 5, + iotaOnLeft: true, + } + + result := gen.applyIotaOperation(op, 10) + assert.Equal(t, 10, result) // should return iotaVal unchanged +} + +func TestProcessBinaryExprError(t *testing.T) { + gen, err := New("test", "") + require.NoError(t, err) + + state := &constParseState{iotaVal: 5} + + // create an invalid binary expression + expr := &ast.BinaryExpr{ + X: &ast.FuncLit{}, // unsupported type + Op: token.ADD, + Y: &ast.BasicLit{Kind: token.INT, Value: "10"}, + } + + val, op := gen.processBinaryExpr(expr, state) + assert.Equal(t, 0, val) + assert.Nil(t, op) +} + +func TestRightSideDivisionByIota(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.go") + src := `package test +type divByIota int +const ( + divByIotaA = iota // 0 + divByIotaB = 10 / iota // 10/1 = 10 + divByIotaC // 10/2 = 5 + divByIotaD // 10/3 = 3 +) +` + require.NoError(t, os.WriteFile(testFile, []byte(src), 0o644)) + + gen, err := New("divByIota", "") + require.NoError(t, err) + err = gen.Parse(tmpDir) + require.NoError(t, err) + + assert.Equal(t, 0, gen.values["divByIotaA"].value) + assert.Equal(t, 10, gen.values["divByIotaB"].value) + assert.Equal(t, 5, gen.values["divByIotaC"].value) + assert.Equal(t, 3, gen.values["divByIotaD"].value) +} + +func TestMultipleCharactersError(t *testing.T) { + // directly test the multiple characters check in ConvertLiteralToInt + // we need to craft a value that passes strconv.Unquote but has multiple runes + lit := &ast.BasicLit{ + Kind: token.CHAR, + Value: "'\\u0041\\u0042'", // 'AB' - two unicode characters + } + + val, err := ConvertLiteralToInt(lit) + require.Error(t, err) + assert.Contains(t, err.Error(), "character literal") + assert.Equal(t, 0, val) +} + +func TestWriteFilePermissionError(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.go") + src := `package test +type perm int +const ( + permA = iota + permB +) +` + require.NoError(t, os.WriteFile(testFile, []byte(src), 0o644)) + + // create a read-only directory + readOnlyDir := filepath.Join(tmpDir, "readonly") + require.NoError(t, os.MkdirAll(readOnlyDir, 0o755)) + + gen, err := New("perm", readOnlyDir) + require.NoError(t, err) + err = gen.Parse(tmpDir) + require.NoError(t, err) + + // make the directory read-only to cause write failure + require.NoError(t, os.Chmod(readOnlyDir, 0o555)) + defer os.Chmod(readOnlyDir, 0o755) // restore permissions for cleanup + + err = gen.Generate() + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to write output file") +} + +func TestParseConstBlockWithImportSpec(t *testing.T) { + // test that parseConstBlock handles non-ValueSpec entries correctly + gen, err := New("test", "") + require.NoError(t, err) + gen.pkgName = "test" + + // create a GenDecl with an ImportSpec (not a ValueSpec) + decl := &ast.GenDecl{ + Tok: token.CONST, + Specs: []ast.Spec{ + &ast.ImportSpec{}, // this should be skipped + }, + } + + // this should not panic and should handle gracefully + gen.parseConstBlock(decl) + + // no values should be added + assert.Empty(t, gen.values) +} + +func TestApplyIotaOperationDivisionByZeroRightSide(t *testing.T) { + gen, err := New("test", "") + require.NoError(t, err) + + // test division when iota is 0 and iota is on the right side + op := &iotaOperation{ + op: token.QUO, + operand: 10, + iotaOnLeft: false, // operand / iota + } + + // when iota is 0, division by zero should return 0 + result := gen.applyIotaOperation(op, 0) + assert.Equal(t, 0, result) +} + +func TestConvertLiteralToIntMultipleRunes(t *testing.T) { + // test the case where strconv.Unquote returns an error + lit := &ast.BasicLit{ + Kind: token.CHAR, + Value: "'\\U00010000\\U00010001'", // invalid: two unicode code points + } + + val, err := ConvertLiteralToInt(lit) + require.Error(t, err) + assert.Contains(t, err.Error(), "cannot parse character literal") + assert.Equal(t, 0, val) +} diff --git a/internal/generator/testdata/binary_expr.go b/internal/generator/testdata/binary_expr.go index 9ca9ffa..9e9ecfe 100644 --- a/internal/generator/testdata/binary_expr.go +++ b/internal/generator/testdata/binary_expr.go @@ -6,4 +6,4 @@ const ( binaryExprFirst binaryExpr = iota + 1 binaryExprSecond binaryExprThird -) \ No newline at end of file +) diff --git a/internal/generator/testdata/edge_cases.go b/internal/generator/testdata/edge_cases.go new file mode 100644 index 0000000..7b5c924 --- /dev/null +++ b/internal/generator/testdata/edge_cases.go @@ -0,0 +1,27 @@ +package testdata + +// test multiplication and division +type mulDivType uint8 + +const ( + mulDivTypeA mulDivType = iota * 2 // 0 + mulDivTypeB // 2 + mulDivTypeC // 4 +) + +// test right-side iota +type rightIotaType uint8 + +const ( + rightIotaTypeX rightIotaType = 10 + iota // 10 + rightIotaTypeY // 11 +) + +// test subtraction +type subType int + +const ( + subTypeA subType = 100 - iota // 100 + subTypeB // 99 + subTypeC // 98 +) diff --git a/internal/generator/testdata/no_zero.go b/internal/generator/testdata/no_zero.go new file mode 100644 index 0000000..78232e8 --- /dev/null +++ b/internal/generator/testdata/no_zero.go @@ -0,0 +1,9 @@ +package testdata + +type noZero uint8 + +const ( + noZeroFirst noZero = 1 // Start at 1, no zero value + noZeroSecond noZero = 2 + noZeroThird noZero = 3 +) diff --git a/internal/generator/testdata/order_test.go b/internal/generator/testdata/order_test.go new file mode 100644 index 0000000..3770185 --- /dev/null +++ b/internal/generator/testdata/order_test.go @@ -0,0 +1,10 @@ +package testdata + +type orderTest uint8 + +const ( + orderTestZero orderTest = iota // Should be first (alphabetically would be fourth) + orderTestAlpha // Should be second (alphabetically would be first) + orderTestCharlie // Should be third (alphabetically would be third) + orderTestBravo // Should be fourth (alphabetically would be second) +) diff --git a/internal/generator/testdata/various_types.go b/internal/generator/testdata/various_types.go new file mode 100644 index 0000000..5259b88 --- /dev/null +++ b/internal/generator/testdata/various_types.go @@ -0,0 +1,30 @@ +package testdata + +// test various underlying types +type uint16Type uint16 + +const ( + uint16TypeFirst uint16Type = iota + uint16TypeSecond +) + +type int32Type int32 + +const ( + int32TypeAlpha int32Type = iota + 100 + int32TypeBeta +) + +type byteType byte + +const ( + byteTypeA byteType = iota + byteTypeB +) + +type runeType rune + +const ( + runeTypeX runeType = 'A' + runeTypeY runeType = 'B' +) diff --git a/main.go b/main.go index 0e87156..dadc4ec 100644 --- a/main.go +++ b/main.go @@ -16,8 +16,8 @@ var osExit = os.Exit func main() { typeFlag := flag.String("type", "", "type name (must be lowercase)") pathFlag := flag.String("path", "", "output directory path (default: same as source)") - lowerFlag := flag.Bool("lower", false, "use lower case for marshaled/unmarshaled values") - getterFlag := flag.Bool("getter", false, "generate getter methods for enum values") + lowerFlag := flag.Bool("lower", false, "use lowercase for string representation (e.g., 'active' instead of 'Active')") + getterFlag := flag.Bool("getter", false, "generate GetByID function to retrieve enum by integer value (requires unique IDs)") helpFlag := flag.Bool("help", false, "show usage") versionFlag := flag.Bool("version", false, "print version") flag.Parse() @@ -66,9 +66,7 @@ func main() { } func showUsage() { - fmt.Printf("usage: enumgen -type [-path ] [-lower] [-version]\n") - fmt.Printf(" -type type name (must be lowercase)\n") - fmt.Printf(" -path output directory path (default: same as source)\n") - fmt.Printf(" -lower use lower case for marshaled/unmarshaled values\n") - fmt.Printf(" -version print version\n") + fmt.Printf("usage: enum [flags]\n\n") + fmt.Printf("Flags:\n") + flag.PrintDefaults() } diff --git a/main_test.go b/main_test.go index b582ab9..534dbf6 100644 --- a/main_test.go +++ b/main_test.go @@ -166,4 +166,82 @@ const ( main() assert.Equal(t, 1, exitCode) }) + + t.Run("parse error - no matching constants", func(t *testing.T) { + // reset flags for this run + flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError) + + origArgs := os.Args + origWd, err := os.Getwd() + require.NoError(t, err) + defer func() { + os.Args = origArgs + require.NoError(t, os.Chdir(origWd)) + }() + + tmpDir := t.TempDir() + + // create a file without matching constants + err = os.WriteFile(filepath.Join(tmpDir, "status.go"), []byte(` +package test +type status uint8 +const ( + SomethingElse = 1 +) +`), 0o644) + require.NoError(t, err) + + // change working directory to temp dir + require.NoError(t, os.Chdir(tmpDir)) + + var exitCode int + osExit = func(code int) { exitCode = code } + + os.Args = []string{"app", "-type", "status"} + main() + assert.Equal(t, 1, exitCode) + }) + + t.Run("generate error - invalid path", func(t *testing.T) { + // reset flags for this run + flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError) + + origArgs := os.Args + origWd, err := os.Getwd() + require.NoError(t, err) + defer func() { + os.Args = origArgs + require.NoError(t, os.Chdir(origWd)) + }() + + tmpDir := t.TempDir() + + // create a valid file + err = os.WriteFile(filepath.Join(tmpDir, "status.go"), []byte(` +package test +type status uint8 +const ( + statusActive status = iota + statusInactive +) +`), 0o644) + require.NoError(t, err) + + // change working directory to temp dir + require.NoError(t, os.Chdir(tmpDir)) + + // make the directory read-only to cause generate to fail + require.NoError(t, os.Chmod(tmpDir, 0o555)) + defer func() { + // restore permissions so cleanup can work + os.Chmod(tmpDir, 0o755) + }() + + var exitCode int + osExit = func(code int) { exitCode = code } + + os.Args = []string{"app", "-type", "status"} + main() + assert.Equal(t, 1, exitCode) + }) }