Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ go generate ./...

### Generator Options

- `-type` (required): the name of the type to generate enum for (must be lowercase)
- `-type` (required): the name of the type to generate enum for (must be private)
- `-path`: output directory path (default: same as source)
- `-lower`: use lowercase for marshaled/unmarshaled values
- `-version`: print version information
Expand Down
113 changes: 113 additions & 0 deletions _examples/status/job_status_enum.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

116 changes: 116 additions & 0 deletions _examples/status/job_status_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package status

import (
"database/sql"
"encoding/json"
"fmt"
"testing"

_ "modernc.org/sqlite"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestJobStatus(t *testing.T) {

t.Run("basic", func(t *testing.T) {
s := JobStatusActive
assert.Equal(t, "active", s.String())
})

t.Run("json", func(t *testing.T) {
type Data struct {
Status JobStatus `json:"status"`
}

d := Data{Status: JobStatusActive}
b, err := json.Marshal(d)
require.NoError(t, err)
assert.Equal(t, `{"status":"active"}`, string(b))

var d2 Data
err = json.Unmarshal([]byte(`{"status":"inactive"}`), &d2)
require.NoError(t, err)
assert.Equal(t, JobStatusInactive, d2.Status)
})

t.Run("sql", func(t *testing.T) {
s := JobStatusActive

// test Value() method
v, err := s.Value()
require.NoError(t, err)
assert.Equal(t, "active", v)

// test Scan from string
var s2 JobStatus
err = s2.Scan("inactive")
require.NoError(t, err)
assert.Equal(t, JobStatusInactive, s2)

// test Scan from []byte
err = s2.Scan([]byte("blocked"))
require.NoError(t, err)
assert.Equal(t, JobStatusBlocked, s2)

// test Scan from nil - should get first value from StatusValues()
err = s2.Scan(nil)
require.NoError(t, err)
assert.Equal(t, JobStatusValues()[0], s2)

// test invalid value
err = s2.Scan(123)
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid jobStatus value")
})

t.Run("sqlite", func(t *testing.T) {
db, err := sql.Open("sqlite", ":memory:")
require.NoError(t, err)
defer db.Close()

// create table with status column
_, err = db.Exec(`CREATE TABLE test_status (id INTEGER PRIMARY KEY, status TEXT)`)
require.NoError(t, err)

// insert different status values
statuses := []JobStatus{JobStatusActive, JobStatusInactive, JobStatusBlocked}
for i, s := range statuses {
_, err = db.Exec(`INSERT INTO test_status (id, status) VALUES (?, ?)`, i+1, s)
require.NoError(t, err)
}

// insert nil status
_, err = db.Exec(`INSERT INTO test_status (id, status) VALUES (?, ?)`, 4, nil)
require.NoError(t, err)

// read and verify each status
for i, expected := range statuses {
var s JobStatus
err = db.QueryRow(`SELECT status FROM test_status WHERE id = ?`, i+1).Scan(&s)
require.NoError(t, err)
assert.Equal(t, expected, s)
}

// verify nil status gets first value
var s JobStatus
err = db.QueryRow(`SELECT status FROM test_status WHERE id = 4`).Scan(&s)
require.NoError(t, err)
assert.Equal(t, JobStatusValues()[0], s)
})

t.Run("invalid", func(t *testing.T) {
var d struct {
Status JobStatus `json:"status"`
}
err := json.Unmarshal([]byte(`{"status":"invalid"}`), &d)
assert.Error(t, err)
})
}

func ExampleJobStatus() {
s := JobStatusActive
fmt.Println(s.String())
// Output: active
}
10 changes: 10 additions & 0 deletions _examples/status/status.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package status

//go:generate go run ../../main.go -type status -lower
//go:generate go run ../../main.go -type jobStatus -lower

type status uint8

Expand All @@ -10,3 +11,12 @@ const (
statusInactive
statusBlocked
)

type jobStatus uint8

const (
jobStatusUnknown jobStatus = iota
jobStatusActive
jobStatusInactive
jobStatusBlocked
)
55 changes: 51 additions & 4 deletions internal/generator/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@ import (
"strings"
"text/template"
"unicode"
"unicode/utf8"

"golang.org/x/text/cases"
"golang.org/x/text/language"
)

var titleCaser = cases.Title(language.English)
var titleCaser = cases.Title(language.English, cases.NoLower)

// Generator holds the data needed for enum code generation
type Generator struct {
Expand All @@ -44,8 +45,8 @@ func New(typeName, path string) (*Generator, error) {
if typeName == "" {
return nil, fmt.Errorf("type name is required")
}
if strings.ToLower(typeName) != typeName {
return nil, fmt.Errorf("type name must be lowercase (private)")
if !unicode.IsLower(rune(typeName[0])) {
return nil, fmt.Errorf("first letter must be lowercase (private)")
}

return &Generator{
Expand Down Expand Up @@ -199,14 +200,60 @@ func (g *Generator) Generate() error {
}

// write generated code to file
outputName := filepath.Join(g.Path, g.Type+"_enum.go")
outputName := filepath.Join(g.Path, getFileNameForType(g.Type))
if err := os.WriteFile(outputName, src, 0o600); err != nil {
return fmt.Errorf("failed to write output file: %w", err)
}

return nil
}

// splitCamelCase splits a camel case string into words, it handles the sequential abbreviations
// and acronyms by treating them as single words.
// For example:
// "jobStatus" becomes ["job", "Status"].
// "internalIPAddress" becomes ["internal", "IP", "Address"].
// "internalIP" becomes ["internal", "IP"].
// "HTTPResponse" becomes ["HTTP", "Response"].
// "HTTP" is not split further.
func splitCamelCase(s string) []string {
var words []string
start := 0
var prev rune
for i, curr := range s {
if i == 0 {
prev = curr
continue
}
_, width := utf8.DecodeRuneInString(s[i:])
var next *rune
if i+width < len(s) {
nextr, _ := utf8.DecodeRuneInString(s[i+width:])
next = &nextr
}
if (unicode.IsLower(prev) && unicode.IsUpper(curr)) ||
(unicode.IsUpper(curr) && (next != nil && unicode.IsLower(*next))) {
words = append(words, s[start:i])
start = i
}
prev = curr
}
words = append(words, s[start:])
return words
}

// getFileNameForType returns the file name for the generated enum code based on the type name.
// It converts the type name to snake case and appends "_enum.go" to it.
// For example, if the type name is "jobStatus", the file name will be "job_status_enum.go".
func getFileNameForType(typeName string) string {
words := splitCamelCase(typeName)
for i := range words {
words[i] = strings.ToLower(words[i])
}

return strings.Join(words, "_") + "_enum.go"
}

// isValidGoIdentifier checks if a string is a valid Go identifier:
// - must start with a letter or underscore
// - can contain letters, digits, and underscores
Expand Down
Loading