diff --git a/README.md b/README.md index 6c7055c..b5d64ac 100644 --- a/README.md +++ b/README.md @@ -86,7 +86,7 @@ go generate ./... ### Generator Options -- `-type` (required): the name of the type to generate enum for (must be lowercase) +- `-type` (required): the name of the type to generate enum for (must be private) - `-path`: output directory path (default: same as source) - `-lower`: use lowercase for marshaled/unmarshaled values - `-version`: print version information diff --git a/_examples/status/job_status_enum.go b/_examples/status/job_status_enum.go new file mode 100644 index 0000000..970a5f7 --- /dev/null +++ b/_examples/status/job_status_enum.go @@ -0,0 +1,113 @@ +// Code generated by enum generator; DO NOT EDIT. +package status + +import ( + "fmt" + + "database/sql/driver" +) + +// JobStatus is the exported type for the enum +type JobStatus struct { + name string + value int +} + +func (e JobStatus) String() string { return e.name } + +// MarshalText implements encoding.TextMarshaler +func (e JobStatus) MarshalText() ([]byte, error) { + return []byte(e.name), nil +} + +// UnmarshalText implements encoding.TextUnmarshaler +func (e *JobStatus) UnmarshalText(text []byte) error { + var err error + *e, err = ParseJobStatus(string(text)) + return err +} + +// Value implements the driver.Valuer interface +func (e JobStatus) Value() (driver.Value, error) { + return e.name, nil +} + +// Scan implements the sql.Scanner interface +func (e *JobStatus) Scan(value interface{}) error { + if value == nil { + *e = JobStatusValues()[0] + return nil + } + + str, ok := value.(string) + if !ok { + if b, ok := value.([]byte); ok { + str = string(b) + } else { + return fmt.Errorf("invalid jobStatus value: %v", value) + } + } + + val, err := ParseJobStatus(str) + if err != nil { + return err + } + + *e = val + return nil +} + +// ParseJobStatus converts string to jobStatus enum value +func ParseJobStatus(v string) (JobStatus, error) { + + switch v { + case "active": + return JobStatusActive, nil + case "blocked": + return JobStatusBlocked, nil + case "inactive": + return JobStatusInactive, nil + case "unknown": + return JobStatusUnknown, nil + + } + + return JobStatus{}, fmt.Errorf("invalid jobStatus: %s", v) +} + +// MustJobStatus is like ParseJobStatus but panics if string is invalid +func MustJobStatus(v string) JobStatus { + r, err := ParseJobStatus(v) + if err != nil { + panic(err) + } + return r +} + +// Public constants for jobStatus values +var ( + JobStatusActive = JobStatus{name: "active", value: 1} + JobStatusBlocked = JobStatus{name: "blocked", value: 3} + JobStatusInactive = JobStatus{name: "inactive", value: 2} + JobStatusUnknown = JobStatus{name: "unknown", value: 0} +) + +// JobStatusValues returns all possible enum values +func JobStatusValues() []JobStatus { + return []JobStatus{ + JobStatusActive, + JobStatusBlocked, + JobStatusInactive, + JobStatusUnknown, + } +} + +// JobStatusNames returns all possible enum names +func JobStatusNames() []string { + return []string{ + "active", + "blocked", + "inactive", + "unknown", + } +} diff --git a/_examples/status/job_status_test.go b/_examples/status/job_status_test.go new file mode 100644 index 0000000..a39f60a --- /dev/null +++ b/_examples/status/job_status_test.go @@ -0,0 +1,116 @@ +package status + +import ( + "database/sql" + "encoding/json" + "fmt" + "testing" + + _ "modernc.org/sqlite" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestJobStatus(t *testing.T) { + + t.Run("basic", func(t *testing.T) { + s := JobStatusActive + assert.Equal(t, "active", s.String()) + }) + + t.Run("json", func(t *testing.T) { + type Data struct { + Status JobStatus `json:"status"` + } + + d := Data{Status: JobStatusActive} + b, err := json.Marshal(d) + require.NoError(t, err) + assert.Equal(t, `{"status":"active"}`, string(b)) + + var d2 Data + err = json.Unmarshal([]byte(`{"status":"inactive"}`), &d2) + require.NoError(t, err) + assert.Equal(t, JobStatusInactive, d2.Status) + }) + + t.Run("sql", func(t *testing.T) { + s := JobStatusActive + + // test Value() method + v, err := s.Value() + require.NoError(t, err) + assert.Equal(t, "active", v) + + // test Scan from string + var s2 JobStatus + err = s2.Scan("inactive") + require.NoError(t, err) + assert.Equal(t, JobStatusInactive, s2) + + // test Scan from []byte + err = s2.Scan([]byte("blocked")) + require.NoError(t, err) + assert.Equal(t, JobStatusBlocked, s2) + + // test Scan from nil - should get first value from StatusValues() + err = s2.Scan(nil) + require.NoError(t, err) + assert.Equal(t, JobStatusValues()[0], s2) + + // test invalid value + err = s2.Scan(123) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid jobStatus value") + }) + + t.Run("sqlite", func(t *testing.T) { + db, err := sql.Open("sqlite", ":memory:") + require.NoError(t, err) + defer db.Close() + + // create table with status column + _, err = db.Exec(`CREATE TABLE test_status (id INTEGER PRIMARY KEY, status TEXT)`) + require.NoError(t, err) + + // insert different status values + statuses := []JobStatus{JobStatusActive, JobStatusInactive, JobStatusBlocked} + for i, s := range statuses { + _, err = db.Exec(`INSERT INTO test_status (id, status) VALUES (?, ?)`, i+1, s) + require.NoError(t, err) + } + + // insert nil status + _, err = db.Exec(`INSERT INTO test_status (id, status) VALUES (?, ?)`, 4, nil) + require.NoError(t, err) + + // read and verify each status + for i, expected := range statuses { + var s JobStatus + err = db.QueryRow(`SELECT status FROM test_status WHERE id = ?`, i+1).Scan(&s) + require.NoError(t, err) + assert.Equal(t, expected, s) + } + + // verify nil status gets first value + var s JobStatus + err = db.QueryRow(`SELECT status FROM test_status WHERE id = 4`).Scan(&s) + require.NoError(t, err) + assert.Equal(t, JobStatusValues()[0], s) + }) + + t.Run("invalid", func(t *testing.T) { + var d struct { + Status JobStatus `json:"status"` + } + err := json.Unmarshal([]byte(`{"status":"invalid"}`), &d) + assert.Error(t, err) + }) +} + +func ExampleJobStatus() { + s := JobStatusActive + fmt.Println(s.String()) + // Output: active +} diff --git a/_examples/status/status.go b/_examples/status/status.go index fdfdc49..fe6386b 100644 --- a/_examples/status/status.go +++ b/_examples/status/status.go @@ -1,6 +1,7 @@ package status //go:generate go run ../../main.go -type status -lower +//go:generate go run ../../main.go -type jobStatus -lower type status uint8 @@ -10,3 +11,12 @@ const ( statusInactive statusBlocked ) + +type jobStatus uint8 + +const ( + jobStatusUnknown jobStatus = iota + jobStatusActive + jobStatusInactive + jobStatusBlocked +) diff --git a/internal/generator/generator.go b/internal/generator/generator.go index b2b9bf8..697ce1d 100644 --- a/internal/generator/generator.go +++ b/internal/generator/generator.go @@ -15,12 +15,13 @@ import ( "strings" "text/template" "unicode" + "unicode/utf8" "golang.org/x/text/cases" "golang.org/x/text/language" ) -var titleCaser = cases.Title(language.English) +var titleCaser = cases.Title(language.English, cases.NoLower) // Generator holds the data needed for enum code generation type Generator struct { @@ -44,8 +45,8 @@ func New(typeName, path string) (*Generator, error) { if typeName == "" { return nil, fmt.Errorf("type name is required") } - if strings.ToLower(typeName) != typeName { - return nil, fmt.Errorf("type name must be lowercase (private)") + if !unicode.IsLower(rune(typeName[0])) { + return nil, fmt.Errorf("first letter must be lowercase (private)") } return &Generator{ @@ -199,7 +200,7 @@ func (g *Generator) Generate() error { } // write generated code to file - outputName := filepath.Join(g.Path, g.Type+"_enum.go") + outputName := filepath.Join(g.Path, getFileNameForType(g.Type)) if err := os.WriteFile(outputName, src, 0o600); err != nil { return fmt.Errorf("failed to write output file: %w", err) } @@ -207,6 +208,52 @@ func (g *Generator) Generate() error { return nil } +// splitCamelCase splits a camel case string into words, it handles the sequential abbreviations +// and acronyms by treating them as single words. +// For example: +// "jobStatus" becomes ["job", "Status"]. +// "internalIPAddress" becomes ["internal", "IP", "Address"]. +// "internalIP" becomes ["internal", "IP"]. +// "HTTPResponse" becomes ["HTTP", "Response"]. +// "HTTP" is not split further. +func splitCamelCase(s string) []string { + var words []string + start := 0 + var prev rune + for i, curr := range s { + if i == 0 { + prev = curr + continue + } + _, width := utf8.DecodeRuneInString(s[i:]) + var next *rune + if i+width < len(s) { + nextr, _ := utf8.DecodeRuneInString(s[i+width:]) + next = &nextr + } + if (unicode.IsLower(prev) && unicode.IsUpper(curr)) || + (unicode.IsUpper(curr) && (next != nil && unicode.IsLower(*next))) { + words = append(words, s[start:i]) + start = i + } + prev = curr + } + words = append(words, s[start:]) + return words +} + +// getFileNameForType returns the file name for the generated enum code based on the type name. +// It converts the type name to snake case and appends "_enum.go" to it. +// For example, if the type name is "jobStatus", the file name will be "job_status_enum.go". +func getFileNameForType(typeName string) string { + words := splitCamelCase(typeName) + for i := range words { + words[i] = strings.ToLower(words[i]) + } + + return strings.Join(words, "_") + "_enum.go" +} + // isValidGoIdentifier checks if a string is a valid Go identifier: // - must start with a letter or underscore // - can contain letters, digits, and underscores diff --git a/internal/generator/generator_test.go b/internal/generator/generator_test.go index c8c7d18..5dc1d5d 100644 --- a/internal/generator/generator_test.go +++ b/internal/generator/generator_test.go @@ -25,6 +25,10 @@ func TestGenerator(t *testing.T) { require.NoError(t, err) assert.NotNil(t, gen) + gen, err = New("moreComplexType", "") + require.NoError(t, err) + assert.NotNil(t, gen) + // check if generated code is valid Go code tmpDir := t.TempDir() gen, err = New("status", tmpDir) @@ -98,6 +102,32 @@ func TestGenerator(t *testing.T) { assert.Contains(t, string(content), "StatusBlocked") }) + t.Run("parse and generate with complex name", func(t *testing.T) { + // create temp dir for output + tmpDir := t.TempDir() + + gen, err := New("jobStatus", tmpDir) + require.NoError(t, err) + + // parse testdata + err = gen.Parse("testdata") + require.NoError(t, err) + + // generate + err = gen.Generate() + require.NoError(t, err) + + // verify file was created + content, err := os.ReadFile(filepath.Join(tmpDir, "job_status_enum.go")) + require.NoError(t, err) + + // check content + assert.Contains(t, string(content), "type JobStatus struct") + assert.Contains(t, string(content), "JobStatusActive") + assert.Contains(t, string(content), "JobStatusInactive") + assert.Contains(t, string(content), "JobStatusBlocked") + }) + t.Run("sql support", func(t *testing.T) { tmpDir := t.TempDir() gen, err := New("status", tmpDir) @@ -375,3 +405,41 @@ const name string assert.Contains(t, err.Error(), "no const values found for type status") }) } + +func TestSplitCamelCase(t *testing.T) { + tests := []struct { + input string + expected []string + }{ + {"", []string{""}}, + {"status", []string{"status"}}, + {"internalIPAddress", []string{"internal", "IP", "Address"}}, + {"internalIP", []string{"internal", "IP"}}, + {"HTTP", []string{"HTTP"}}, + {"HTTPResponseCode", []string{"HTTP", "Response", "Code"}}, + } + + for _, test := range tests { + t.Run(test.input, func(t *testing.T) { + result := splitCamelCase(test.input) + assert.Equal(t, test.expected, result) + }) + } +} + +func TestGetFileNameForType(t *testing.T) { + tests := []struct { + typeName string + expected string + }{ + {"status", "status_enum.go"}, + {"jobStatus", "job_status_enum.go"}, + } + + for _, test := range tests { + t.Run(test.typeName, func(t *testing.T) { + result := getFileNameForType(test.typeName) + assert.Equal(t, test.expected, result) + }) + } +} diff --git a/internal/generator/testdata/job_status.go b/internal/generator/testdata/job_status.go new file mode 100644 index 0000000..2229dfb --- /dev/null +++ b/internal/generator/testdata/job_status.go @@ -0,0 +1,10 @@ +package testdata + +type jobStatus uint8 + +const ( + jobStatusUnknown jobStatus = iota + jobStatusActive + jobStatusInactive + jobStatusBlocked +)