Skip to content

Commit

Permalink
feat: make encoding configurable in initdb (#133)
Browse files Browse the repository at this point in the history
  • Loading branch information
zzzFelix committed May 10, 2024
1 parent 74c945b commit c00e987
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 8 deletions.
7 changes: 7 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type Config struct {
dataPath string
binariesPath string
locale string
encoding string
startParameters map[string]string
binaryRepositoryURL string
startTimeout time.Duration
Expand Down Expand Up @@ -110,6 +111,12 @@ func (c Config) Locale(locale string) Config {
return c
}

// Encoding sets the default character set for initdb
func (c Config) Encoding(encoding string) Config {
c.encoding = encoding
return c
}

// StartParameters sets run-time parameters when starting Postgres (passed to Postgres via "-c").
//
// These parameters can be used to override the default configuration values in postgres.conf such
Expand Down
2 changes: 1 addition & 1 deletion embedded_postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ func (ep *EmbeddedPostgres) cleanDataDirectoryAndInit() error {
return fmt.Errorf("unable to clean up data directory %s with error: %s", ep.config.dataPath, err)
}

if err := ep.initDatabase(ep.config.binariesPath, ep.config.runtimePath, ep.config.dataPath, ep.config.username, ep.config.password, ep.config.locale, ep.syncedLogger.file); err != nil {
if err := ep.initDatabase(ep.config.binariesPath, ep.config.runtimePath, ep.config.dataPath, ep.config.username, ep.config.password, ep.config.locale, ep.config.encoding, ep.syncedLogger.file); err != nil {
return err
}

Expand Down
35 changes: 33 additions & 2 deletions embedded_postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ func Test_ErrorWhenUnableToInitDatabase(t *testing.T) {
return jarFile, true
}

database.initDatabase = func(binaryExtractLocation, runtimePath, dataLocation, username, password, locale string, logger *os.File) error {
database.initDatabase = func(binaryExtractLocation, runtimePath, dataLocation, username, password, locale string, encoding string, logger *os.File) error {
return errors.New("ah it did not work")
}

Expand Down Expand Up @@ -226,7 +226,7 @@ func Test_ErrorWhenCannotStartPostgresProcess(t *testing.T) {
return jarFile, true
}

database.initDatabase = func(binaryExtractLocation, runtimePath, dataLocation, username, password, locale string, logger *os.File) error {
database.initDatabase = func(binaryExtractLocation, runtimePath, dataLocation, username, password, locale string, encoding string, logger *os.File) error {
_, _ = logger.Write([]byte("ah it did not work"))
return nil
}
Expand Down Expand Up @@ -257,6 +257,7 @@ func Test_CustomConfig(t *testing.T) {
Port(9876).
StartTimeout(10 * time.Second).
Locale("C").
Encoding("UTF8").
Logger(nil))

if err := database.Start(); err != nil {
Expand Down Expand Up @@ -356,6 +357,36 @@ func Test_CustomLocaleConfig(t *testing.T) {
}
}

func Test_CustomEncodingConfig(t *testing.T) {
database := NewDatabase(DefaultConfig().Encoding("UTF8"))
if err := database.Start(); err != nil {
shutdownDBAndFail(t, err, database)
}

db, err := sql.Open("postgres", "host=localhost port=5432 user=postgres password=postgres dbname=postgres sslmode=disable")
if err != nil {
shutdownDBAndFail(t, err, database)
}

rows := db.QueryRow("SHOW SERVER_ENCODING;")

var (
value string
)
if err := rows.Scan(&value); err != nil {
shutdownDBAndFail(t, err, database)
}
assert.Equal(t, "UTF8", value)

if err := db.Close(); err != nil {
shutdownDBAndFail(t, err, database)
}

if err := database.Stop(); err != nil {
shutdownDBAndFail(t, err, database)
}
}

func Test_ConcurrentStart(t *testing.T) {
var wg sync.WaitGroup

Expand Down
8 changes: 6 additions & 2 deletions prepare_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ const (
fmtAfterError = "%v happened after error: %w"
)

type initDatabase func(binaryExtractLocation, runtimePath, pgDataDir, username, password, locale string, logger *os.File) error
type initDatabase func(binaryExtractLocation, runtimePath, pgDataDir, username, password, locale string, encoding string, logger *os.File) error
type createDatabase func(port uint32, username, password, database string) error

func defaultInitDatabase(binaryExtractLocation, runtimePath, pgDataDir, username, password, locale string, logger *os.File) error {
func defaultInitDatabase(binaryExtractLocation, runtimePath, pgDataDir, username, password, locale string, encoding string, logger *os.File) error {
passwordFile, err := createPasswordFile(runtimePath, password)
if err != nil {
return err
Expand All @@ -38,6 +38,10 @@ func defaultInitDatabase(binaryExtractLocation, runtimePath, pgDataDir, username
args = append(args, fmt.Sprintf("--locale=%s", locale))
}

if encoding != "" {
args = append(args, fmt.Sprintf("--encoding=%s", encoding))
}

postgresInitDBBinary := filepath.Join(binaryExtractLocation, "bin/initdb")
postgresInitDBProcess := exec.Command(postgresInitDBBinary, args...)
postgresInitDBProcess.Stderr = logger
Expand Down
27 changes: 24 additions & 3 deletions prepare_database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
)

func Test_defaultInitDatabase_ErrorWhenCannotCreatePasswordFile(t *testing.T) {
err := defaultInitDatabase("path_not_exists", "path_not_exists", "path_not_exists", "Tom", "Beer", "", os.Stderr)
err := defaultInitDatabase("path_not_exists", "path_not_exists", "path_not_exists", "Tom", "Beer", "", "", os.Stderr)

assert.EqualError(t, err, "unable to write password file to path_not_exists/pwfile")
}
Expand Down Expand Up @@ -49,7 +49,7 @@ func Test_defaultInitDatabase_ErrorWhenCannotStartInitDBProcess(t *testing.T) {

_, _ = logFile.Write([]byte("and here are the logs!"))

err = defaultInitDatabase(binTempDir, runtimeTempDir, filepath.Join(runtimeTempDir, "data"), "Tom", "Beer", "", logFile)
err = defaultInitDatabase(binTempDir, runtimeTempDir, filepath.Join(runtimeTempDir, "data"), "Tom", "Beer", "", "", logFile)

assert.NotNil(t, err)
assert.Contains(t, err.Error(), fmt.Sprintf("unable to init database using '%s/bin/initdb -A password -U Tom -D %s/data --pwfile=%s/pwfile'",
Expand All @@ -72,7 +72,7 @@ func Test_defaultInitDatabase_ErrorInvalidLocaleSetting(t *testing.T) {
}
}()

err = defaultInitDatabase(tempDir, tempDir, filepath.Join(tempDir, "data"), "postgres", "postgres", "en_XY", os.Stderr)
err = defaultInitDatabase(tempDir, tempDir, filepath.Join(tempDir, "data"), "postgres", "postgres", "en_XY", "", os.Stderr)

assert.NotNil(t, err)
assert.Contains(t, err.Error(), fmt.Sprintf("unable to init database using '%s/bin/initdb -A password -U postgres -D %s/data --pwfile=%s/pwfile --locale=en_XY'",
Expand All @@ -81,6 +81,27 @@ func Test_defaultInitDatabase_ErrorInvalidLocaleSetting(t *testing.T) {
tempDir))
}

func Test_defaultInitDatabase_ErrorInvalidEncodingSetting(t *testing.T) {
tempDir, err := os.MkdirTemp("", "prepare_database_test")
if err != nil {
panic(err)
}

defer func() {
if err := os.RemoveAll(tempDir); err != nil {
panic(err)
}
}()

err = defaultInitDatabase(tempDir, tempDir, filepath.Join(tempDir, "data"), "postgres", "postgres", "", "invalid", os.Stderr)

assert.NotNil(t, err)
assert.Contains(t, err.Error(), fmt.Sprintf("unable to init database using '%s/bin/initdb -A password -U postgres -D %s/data --pwfile=%s/pwfile --encoding=invalid'",
tempDir,
tempDir,
tempDir))
}

func Test_defaultInitDatabase_PwFileRemoved(t *testing.T) {
tempDir, err := os.MkdirTemp("", "prepare_database_test")
if err != nil {
Expand Down

0 comments on commit c00e987

Please sign in to comment.