From e846ac4f6dcdc3d02be01e3923aec06ec88e3d49 Mon Sep 17 00:00:00 2001 From: bearsh Date: Tue, 2 Mar 2021 10:56:52 +0100 Subject: [PATCH] make persistence and path of db data directory configurable (#19) * add support for persistent database data - config: add field and setter for db data path - do not delete/reinit db data dir if path is set - data dir must exist and the pg (major) version must match used postgres version - tests: adjust existing test to new initDatabase function signature - tests: add test for resuse data option - platform test: test PG_VERSION file to match requested version - add doc about DataPath config option --- README.md | 7 +++ config.go | 11 ++++- embedded_postgres.go | 73 ++++++++++++++++++++++-------- embedded_postgres_test.go | 83 +++++++++++++++++++++++++++++++++- platform-test/platform_test.go | 27 ++++++++++- prepare_database.go | 6 +-- prepare_database_test.go | 6 +-- 7 files changed, 183 insertions(+), 30 deletions(-) diff --git a/README.md b/README.md index 45f6f27..91f9323 100644 --- a/README.md +++ b/README.md @@ -36,9 +36,16 @@ This library aims to require as little configuration as possible, favouring over | Database | postgres | | Version | 12.1.0 | | RuntimePath | $USER_HOME/.embedded-postgres-go/extracted | +| DataPath | empty, results in *RuntimePath*/data | | Port | 5432 | | StartTimeout | 15 Seconds | +The *RuntimePath* directory is erased and recreated on each start and therefor not +suitable for persistent data. If a persistent data location is required, set +*DataPath* to a directory outside of *RuntimePath*. +If the *RuntimePath* directory is empty or already initialized but with an incompatible +postgres version, it gets cleared and reinitialized. + A single Postgres instance can be created, started and stopped as follows ```go postgres := embeddedpostgres.NewDatabase() diff --git a/config.go b/config.go index feb83e2..5818a2e 100644 --- a/config.go +++ b/config.go @@ -14,6 +14,7 @@ type Config struct { username string password string runtimePath string + dataPath string locale string startTimeout time.Duration logger io.Writer @@ -69,12 +70,20 @@ func (c Config) Password(password string) Config { return c } -// RuntimePath sets the path that will be used for the extracted Postgres runtime and data directory. +// RuntimePath sets the path that will be used for the extracted Postgres runtime directory. +// If Postgres data directory is not set with DataPath(), this directory is also used as data directory. func (c Config) RuntimePath(path string) Config { c.runtimePath = path return c } +// DataPath sets the path that will be used for the Postgres data directory. +// If this option is set, a previously initialized data directory will be reused if possible. +func (c Config) DataPath(path string) Config { + c.dataPath = path + return c +} + // Locale sets the default locale for initdb func (c Config) Locale(locale string) Config { c.locale = locale diff --git a/embedded_postgres.go b/embedded_postgres.go index baa8946..ae46be5 100644 --- a/embedded_postgres.go +++ b/embedded_postgres.go @@ -3,12 +3,13 @@ package embeddedpostgres import ( "errors" "fmt" - "io" + "io/ioutil" "log" "net" "os" "os/exec" "path/filepath" + "strings" "github.com/mholt/archiver/v3" ) @@ -51,6 +52,7 @@ func newDatabaseWithConfig(config Config) *EmbeddedPostgres { // Start will try to start the configured Postgres process returning an error when there were any problems with invocation. // If any error occurs Start will try to also Stop the Postgres process in order to not leave any sub-process running. +//nolint:funlen func (ep *EmbeddedPostgres) Start() error { if ep.started { return errors.New("server is already started") @@ -67,17 +69,27 @@ func (ep *EmbeddedPostgres) Start() error { } } - binaryExtractLocation := userLocationOrDefault(ep.config.runtimePath, cacheLocation) + binaryExtractLocation := userRuntimePathOrDefault(ep.config.runtimePath, cacheLocation) if err := os.RemoveAll(binaryExtractLocation); err != nil { - return fmt.Errorf("unable to clean up directory %s with error: %s", binaryExtractLocation, err) + return fmt.Errorf("unable to clean up runtime directory %s with error: %s", binaryExtractLocation, err) } if err := archiver.NewTarXz().Unarchive(cacheLocation, binaryExtractLocation); err != nil { return fmt.Errorf("unable to extract postgres archive %s to %s", cacheLocation, binaryExtractLocation) } - if err := ep.initDatabase(binaryExtractLocation, ep.config.username, ep.config.password, ep.config.locale, ep.config.logger); err != nil { - return err + dataLocation := userDataPathOrDefault(ep.config.dataPath, binaryExtractLocation) + + reuseData := ep.config.dataPath != "" && dataDirIsValid(dataLocation, ep.config.version) + + if !reuseData { + if err := os.RemoveAll(dataLocation); err != nil { + return fmt.Errorf("unable to clean up data directory %s with error: %s", dataLocation, err) + } + + if err := ep.initDatabase(binaryExtractLocation, dataLocation, ep.config.username, ep.config.password, ep.config.locale, ep.config.logger); err != nil { + return err + } } if err := startPostgres(binaryExtractLocation, ep.config); err != nil { @@ -86,16 +98,18 @@ func (ep *EmbeddedPostgres) Start() error { ep.started = true - if err := ep.createDatabase(ep.config.port, ep.config.username, ep.config.password, ep.config.database); err != nil { - if stopErr := stopPostgres(binaryExtractLocation, ep.config.logger); stopErr != nil { - return fmt.Errorf("unable to stop database casused by error %s", err) - } + if !reuseData { + if err := ep.createDatabase(ep.config.port, ep.config.username, ep.config.password, ep.config.database); err != nil { + if stopErr := stopPostgres(binaryExtractLocation, ep.config); stopErr != nil { + return fmt.Errorf("unable to stop database casused by error %s", err) + } - return err + return err + } } if err := healthCheckDatabaseOrTimeout(ep.config); err != nil { - if stopErr := stopPostgres(binaryExtractLocation, ep.config.logger); stopErr != nil { + if stopErr := stopPostgres(binaryExtractLocation, ep.config); stopErr != nil { return fmt.Errorf("unable to stop database casused by error %s", err) } @@ -112,8 +126,8 @@ func (ep *EmbeddedPostgres) Stop() error { return errors.New("server has not been started") } - binaryExtractLocation := userLocationOrDefault(ep.config.runtimePath, cacheLocation) - if err := stopPostgres(binaryExtractLocation, ep.config.logger); err != nil { + binaryExtractLocation := userRuntimePathOrDefault(ep.config.runtimePath, cacheLocation) + if err := stopPostgres(binaryExtractLocation, ep.config); err != nil { return err } @@ -125,7 +139,7 @@ func (ep *EmbeddedPostgres) Stop() error { func startPostgres(binaryExtractLocation string, config Config) error { postgresBinary := filepath.Join(binaryExtractLocation, "bin/pg_ctl") postgresProcess := exec.Command(postgresBinary, "start", "-w", - "-D", filepath.Join(binaryExtractLocation, "data"), + "-D", userDataPathOrDefault(config.dataPath, binaryExtractLocation), "-o", fmt.Sprintf(`"-p %d"`, config.port)) log.Println(postgresProcess.String()) postgresProcess.Stderr = config.logger @@ -138,12 +152,12 @@ func startPostgres(binaryExtractLocation string, config Config) error { return nil } -func stopPostgres(binaryExtractLocation string, logger io.Writer) error { +func stopPostgres(binaryExtractLocation string, config Config) error { postgresBinary := filepath.Join(binaryExtractLocation, "bin/pg_ctl") postgresProcess := exec.Command(postgresBinary, "stop", "-w", - "-D", filepath.Join(binaryExtractLocation, "data")) - postgresProcess.Stderr = logger - postgresProcess.Stdout = logger + "-D", userDataPathOrDefault(config.dataPath, binaryExtractLocation)) + postgresProcess.Stderr = config.logger + postgresProcess.Stdout = config.logger return postgresProcess.Run() } @@ -161,10 +175,31 @@ func ensurePortAvailable(port uint32) error { return nil } -func userLocationOrDefault(userLocation, cacheLocation string) string { +func userRuntimePathOrDefault(userLocation, cacheLocation string) string { if userLocation != "" { return userLocation } return filepath.Join(filepath.Dir(cacheLocation), "extracted") } + +func userDataPathOrDefault(userLocation, runtimeLocation string) string { + if userLocation != "" { + return userLocation + } + + return filepath.Join(runtimeLocation, "data") +} + +func dataDirIsValid(dataDir string, version PostgresVersion) bool { + pgVersion := filepath.Join(dataDir, "PG_VERSION") + + d, err := ioutil.ReadFile(pgVersion) + if err != nil { + return false + } + + v := strings.TrimSuffix(string(d), "\n") + + return strings.HasPrefix(string(version), v) +} diff --git a/embedded_postgres_test.go b/embedded_postgres_test.go index 8e848e6..d138219 100644 --- a/embedded_postgres_test.go +++ b/embedded_postgres_test.go @@ -118,7 +118,7 @@ func Test_ErrorWhenUnableToInitDatabase(t *testing.T) { return jarFile, true } - database.initDatabase = func(binaryExtractLocation, username, password, locale string, logger io.Writer) error { + database.initDatabase = func(binaryExtractLocation, dataLocation, username, password, locale string, logger io.Writer) error { return errors.New("ah it did not work") } @@ -221,7 +221,7 @@ func Test_ErrorWhenCannotStartPostgresProcess(t *testing.T) { return jarFile, true } - database.initDatabase = func(binaryExtractLocation, username, password, locale string, logger io.Writer) error { + database.initDatabase = func(binaryExtractLocation, dataLocation, username, password, locale string, logger io.Writer) error { return nil } @@ -345,3 +345,82 @@ func Test_CanStartAndStopTwice(t *testing.T) { shutdownDBAndFail(t, err, database) } } + +//nolint:funlen +func Test_ReuseData(t *testing.T) { + tempDir, err := ioutil.TempDir("", "embedded_postgres_test") + if err != nil { + panic(err) + } + + defer func() { + if err := os.RemoveAll(tempDir); err != nil { + panic(err) + } + }() + + database := NewDatabase(DefaultConfig().DataPath(tempDir)) + + if err := database.Start(); err != nil { + shutdownDBAndFail(t, err, database) + } + + db, err := sql.Open("postgres", fmt.Sprintf("host=localhost port=5432 user=postgres password=postgres dbname=postgres sslmode=disable")) + if err != nil { + shutdownDBAndFail(t, err, database) + } + + if _, err = db.Exec("CREATE TABLE test(id serial, value text, PRIMARY KEY(id))"); err != nil { + shutdownDBAndFail(t, err, database) + } + + if _, err = db.Exec("INSERT INTO test (value) VALUES ('foobar')"); err != nil { + shutdownDBAndFail(t, err, database) + } + + if err := db.Close(); err != nil { + shutdownDBAndFail(t, err, database) + } + + if err := database.Stop(); err != nil { + shutdownDBAndFail(t, err, database) + } + + database = NewDatabase(DefaultConfig().DataPath(tempDir)) + + if err := database.Start(); err != nil { + shutdownDBAndFail(t, err, database) + } + + db, err = sql.Open("postgres", fmt.Sprintf("host=localhost port=5432 user=postgres password=postgres dbname=postgres sslmode=disable")) + if err != nil { + shutdownDBAndFail(t, err, database) + } + + if rows, err := db.Query("SELECT * FROM test"); err != nil { + shutdownDBAndFail(t, err, database) + } else { + if !rows.Next() { + shutdownDBAndFail(t, errors.New("no row from db"), database) + } + + var ( + id int64 + value string + ) + if err := rows.Scan(&id, &value); err != nil { + shutdownDBAndFail(t, err, database) + } + if value != "foobar" { + shutdownDBAndFail(t, errors.New("wrong value from db"), database) + } + } + + if err := db.Close(); err != nil { + shutdownDBAndFail(t, err, database) + } + + if err := database.Stop(); err != nil { + shutdownDBAndFail(t, err, database) + } +} diff --git a/platform-test/platform_test.go b/platform-test/platform_test.go index fe2ac49..ea5fdb3 100644 --- a/platform-test/platform_test.go +++ b/platform-test/platform_test.go @@ -7,9 +7,10 @@ import ( "os" "path/filepath" "strconv" + "strings" "testing" - "github.com/fergusstrange/embedded-postgres" + embeddedpostgres "github.com/fergusstrange/embedded-postgres" ) func Test_AllMajorVersions(t *testing.T) { @@ -28,10 +29,11 @@ func Test_AllMajorVersions(t *testing.T) { for testNumber, version := range allVersions { t.Run(fmt.Sprintf("MajorVersion_%d", testNumber), func(t *testing.T) { port := uint32(5555 + testNumber) + runtimePath := filepath.Join(tempExtractLocation, strconv.Itoa(testNumber)) database := embeddedpostgres.NewDatabase(embeddedpostgres.DefaultConfig(). Version(version). Port(port). - RuntimePath(filepath.Join(tempExtractLocation, strconv.Itoa(testNumber)))) + RuntimePath(runtimePath)) if err := database.Start(); err != nil { shutdownDBAndFail(t, err, database, version) @@ -57,6 +59,10 @@ func Test_AllMajorVersions(t *testing.T) { if err := database.Stop(); err != nil { t.Fatal(err) } + + if err := checkPgVersionFile(filepath.Join(runtimePath, "data"), version); err != nil { + t.Fatal(err) + } }) } if err := os.RemoveAll(tempExtractLocation); err != nil { @@ -75,3 +81,20 @@ func connect(port uint32) (*sql.DB, error) { db, err := sql.Open("postgres", fmt.Sprintf("host=localhost port=%d user=postgres password=postgres dbname=postgres sslmode=disable", port)) return db, err } + +func checkPgVersionFile(dataDir string, version embeddedpostgres.PostgresVersion) error { + pgVersion := filepath.Join(dataDir, "PG_VERSION") + + d, err := ioutil.ReadFile(pgVersion) + if err != nil { + return fmt.Errorf("could not read file %v", pgVersion) + } + + v := strings.TrimSuffix(string(d), "\n") + + if strings.HasPrefix(string(version), v) { + return nil + } + + return fmt.Errorf("version missmatch in PG_VERSION: %v <> %v", string(version), v) +} diff --git a/prepare_database.go b/prepare_database.go index 3287422..b6a2d5e 100644 --- a/prepare_database.go +++ b/prepare_database.go @@ -14,10 +14,10 @@ import ( "github.com/lib/pq" ) -type initDatabase func(binaryExtractLocation, username, password, locale string, logger io.Writer) error +type initDatabase func(binaryExtractLocation, pgDataDir, username, password, locale string, logger io.Writer) error type createDatabase func(port uint32, username, password, database string) error -func defaultInitDatabase(binaryExtractLocation, username, password, locale string, logger io.Writer) error { +func defaultInitDatabase(binaryExtractLocation, pgDataDir, username, password, locale string, logger io.Writer) error { passwordFile, err := createPasswordFile(binaryExtractLocation, password) if err != nil { return err @@ -26,7 +26,7 @@ func defaultInitDatabase(binaryExtractLocation, username, password, locale strin args := []string{ "-A", "password", "-U", username, - "-D", filepath.Join(binaryExtractLocation, "data"), + "-D", pgDataDir, fmt.Sprintf("--pwfile=%s", passwordFile), } diff --git a/prepare_database_test.go b/prepare_database_test.go index c06f919..170898f 100644 --- a/prepare_database_test.go +++ b/prepare_database_test.go @@ -11,7 +11,7 @@ import ( ) func Test_defaultInitDatabase_ErrorWhenCannotCreatePasswordFile(t *testing.T) { - err := defaultInitDatabase("path_not_exists", "Tom", "Beer", "", os.Stderr) + err := defaultInitDatabase("path_not_exists", "path_not_exists", "Tom", "Beer", "", os.Stderr) assert.EqualError(t, err, "unable to write password file to path_not_exists/pwfile") } @@ -28,7 +28,7 @@ func Test_defaultInitDatabase_ErrorWhenCannotStartInitDBProcess(t *testing.T) { } }() - err = defaultInitDatabase(tempDir, "Tom", "Beer", "", os.Stderr) + err = defaultInitDatabase(tempDir, filepath.Join(tempDir, "data"), "Tom", "Beer", "", os.Stderr) assert.EqualError(t, err, fmt.Sprintf("unable to init database using: %s/bin/initdb -A password -U Tom -D %s/data --pwfile=%s/pwfile", tempDir, @@ -49,7 +49,7 @@ func Test_defaultInitDatabase_ErrorInvalidLocaleSetting(t *testing.T) { } }() - err = defaultInitDatabase(tempDir, "postgres", "postgres", "en_XY", os.Stderr) + err = defaultInitDatabase(tempDir, filepath.Join(tempDir, "data"), "postgres", "postgres", "en_XY", os.Stderr) assert.EqualError(t, err, fmt.Sprintf("unable to init database using: %s/bin/initdb -A password -U postgres -D %s/data --pwfile=%s/pwfile --locale=en_XY", tempDir,