Skip to content

Commit

Permalink
make persistence and path of db data directory configurable (#19)
Browse files Browse the repository at this point in the history
* add support for persistent database data

- config: add field and setter for db data path
- do not delete/reinit db data dir if path is set
- data dir must exist and the pg (major) version
  must match used postgres version
- tests: adjust existing test to new initDatabase function signature
- tests: add test for resuse data option
- platform test: test PG_VERSION file to match requested version
- add doc about DataPath config option
  • Loading branch information
bearsh committed Mar 2, 2021
1 parent b9e72af commit e846ac4
Show file tree
Hide file tree
Showing 7 changed files with 183 additions and 30 deletions.
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,16 @@ This library aims to require as little configuration as possible, favouring over
| Database | postgres |
| Version | 12.1.0 |
| RuntimePath | $USER_HOME/.embedded-postgres-go/extracted |
| DataPath | empty, results in *RuntimePath*/data |
| Port | 5432 |
| StartTimeout | 15 Seconds |

The *RuntimePath* directory is erased and recreated on each start and therefor not
suitable for persistent data. If a persistent data location is required, set
*DataPath* to a directory outside of *RuntimePath*.
If the *RuntimePath* directory is empty or already initialized but with an incompatible
postgres version, it gets cleared and reinitialized.

A single Postgres instance can be created, started and stopped as follows
```go
postgres := embeddedpostgres.NewDatabase()
Expand Down
11 changes: 10 additions & 1 deletion config.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ type Config struct {
username string
password string
runtimePath string
dataPath string
locale string
startTimeout time.Duration
logger io.Writer
Expand Down Expand Up @@ -69,12 +70,20 @@ func (c Config) Password(password string) Config {
return c
}

// RuntimePath sets the path that will be used for the extracted Postgres runtime and data directory.
// RuntimePath sets the path that will be used for the extracted Postgres runtime directory.
// If Postgres data directory is not set with DataPath(), this directory is also used as data directory.
func (c Config) RuntimePath(path string) Config {
c.runtimePath = path
return c
}

// DataPath sets the path that will be used for the Postgres data directory.
// If this option is set, a previously initialized data directory will be reused if possible.
func (c Config) DataPath(path string) Config {
c.dataPath = path
return c
}

// Locale sets the default locale for initdb
func (c Config) Locale(locale string) Config {
c.locale = locale
Expand Down
73 changes: 54 additions & 19 deletions embedded_postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ package embeddedpostgres
import (
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"os"
"os/exec"
"path/filepath"
"strings"

"github.com/mholt/archiver/v3"
)
Expand Down Expand Up @@ -51,6 +52,7 @@ func newDatabaseWithConfig(config Config) *EmbeddedPostgres {

// Start will try to start the configured Postgres process returning an error when there were any problems with invocation.
// If any error occurs Start will try to also Stop the Postgres process in order to not leave any sub-process running.
//nolint:funlen
func (ep *EmbeddedPostgres) Start() error {
if ep.started {
return errors.New("server is already started")
Expand All @@ -67,17 +69,27 @@ func (ep *EmbeddedPostgres) Start() error {
}
}

binaryExtractLocation := userLocationOrDefault(ep.config.runtimePath, cacheLocation)
binaryExtractLocation := userRuntimePathOrDefault(ep.config.runtimePath, cacheLocation)
if err := os.RemoveAll(binaryExtractLocation); err != nil {
return fmt.Errorf("unable to clean up directory %s with error: %s", binaryExtractLocation, err)
return fmt.Errorf("unable to clean up runtime directory %s with error: %s", binaryExtractLocation, err)
}

if err := archiver.NewTarXz().Unarchive(cacheLocation, binaryExtractLocation); err != nil {
return fmt.Errorf("unable to extract postgres archive %s to %s", cacheLocation, binaryExtractLocation)
}

if err := ep.initDatabase(binaryExtractLocation, ep.config.username, ep.config.password, ep.config.locale, ep.config.logger); err != nil {
return err
dataLocation := userDataPathOrDefault(ep.config.dataPath, binaryExtractLocation)

reuseData := ep.config.dataPath != "" && dataDirIsValid(dataLocation, ep.config.version)

if !reuseData {
if err := os.RemoveAll(dataLocation); err != nil {
return fmt.Errorf("unable to clean up data directory %s with error: %s", dataLocation, err)
}

if err := ep.initDatabase(binaryExtractLocation, dataLocation, ep.config.username, ep.config.password, ep.config.locale, ep.config.logger); err != nil {
return err
}
}

if err := startPostgres(binaryExtractLocation, ep.config); err != nil {
Expand All @@ -86,16 +98,18 @@ func (ep *EmbeddedPostgres) Start() error {

ep.started = true

if err := ep.createDatabase(ep.config.port, ep.config.username, ep.config.password, ep.config.database); err != nil {
if stopErr := stopPostgres(binaryExtractLocation, ep.config.logger); stopErr != nil {
return fmt.Errorf("unable to stop database casused by error %s", err)
}
if !reuseData {
if err := ep.createDatabase(ep.config.port, ep.config.username, ep.config.password, ep.config.database); err != nil {
if stopErr := stopPostgres(binaryExtractLocation, ep.config); stopErr != nil {
return fmt.Errorf("unable to stop database casused by error %s", err)
}

return err
return err
}
}

if err := healthCheckDatabaseOrTimeout(ep.config); err != nil {
if stopErr := stopPostgres(binaryExtractLocation, ep.config.logger); stopErr != nil {
if stopErr := stopPostgres(binaryExtractLocation, ep.config); stopErr != nil {
return fmt.Errorf("unable to stop database casused by error %s", err)
}

Expand All @@ -112,8 +126,8 @@ func (ep *EmbeddedPostgres) Stop() error {
return errors.New("server has not been started")
}

binaryExtractLocation := userLocationOrDefault(ep.config.runtimePath, cacheLocation)
if err := stopPostgres(binaryExtractLocation, ep.config.logger); err != nil {
binaryExtractLocation := userRuntimePathOrDefault(ep.config.runtimePath, cacheLocation)
if err := stopPostgres(binaryExtractLocation, ep.config); err != nil {
return err
}

Expand All @@ -125,7 +139,7 @@ func (ep *EmbeddedPostgres) Stop() error {
func startPostgres(binaryExtractLocation string, config Config) error {
postgresBinary := filepath.Join(binaryExtractLocation, "bin/pg_ctl")
postgresProcess := exec.Command(postgresBinary, "start", "-w",
"-D", filepath.Join(binaryExtractLocation, "data"),
"-D", userDataPathOrDefault(config.dataPath, binaryExtractLocation),
"-o", fmt.Sprintf(`"-p %d"`, config.port))
log.Println(postgresProcess.String())
postgresProcess.Stderr = config.logger
Expand All @@ -138,12 +152,12 @@ func startPostgres(binaryExtractLocation string, config Config) error {
return nil
}

func stopPostgres(binaryExtractLocation string, logger io.Writer) error {
func stopPostgres(binaryExtractLocation string, config Config) error {
postgresBinary := filepath.Join(binaryExtractLocation, "bin/pg_ctl")
postgresProcess := exec.Command(postgresBinary, "stop", "-w",
"-D", filepath.Join(binaryExtractLocation, "data"))
postgresProcess.Stderr = logger
postgresProcess.Stdout = logger
"-D", userDataPathOrDefault(config.dataPath, binaryExtractLocation))
postgresProcess.Stderr = config.logger
postgresProcess.Stdout = config.logger

return postgresProcess.Run()
}
Expand All @@ -161,10 +175,31 @@ func ensurePortAvailable(port uint32) error {
return nil
}

func userLocationOrDefault(userLocation, cacheLocation string) string {
func userRuntimePathOrDefault(userLocation, cacheLocation string) string {
if userLocation != "" {
return userLocation
}

return filepath.Join(filepath.Dir(cacheLocation), "extracted")
}

func userDataPathOrDefault(userLocation, runtimeLocation string) string {
if userLocation != "" {
return userLocation
}

return filepath.Join(runtimeLocation, "data")
}

func dataDirIsValid(dataDir string, version PostgresVersion) bool {
pgVersion := filepath.Join(dataDir, "PG_VERSION")

d, err := ioutil.ReadFile(pgVersion)
if err != nil {
return false
}

v := strings.TrimSuffix(string(d), "\n")

return strings.HasPrefix(string(version), v)
}
83 changes: 81 additions & 2 deletions embedded_postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func Test_ErrorWhenUnableToInitDatabase(t *testing.T) {
return jarFile, true
}

database.initDatabase = func(binaryExtractLocation, username, password, locale string, logger io.Writer) error {
database.initDatabase = func(binaryExtractLocation, dataLocation, username, password, locale string, logger io.Writer) error {
return errors.New("ah it did not work")
}

Expand Down Expand Up @@ -221,7 +221,7 @@ func Test_ErrorWhenCannotStartPostgresProcess(t *testing.T) {
return jarFile, true
}

database.initDatabase = func(binaryExtractLocation, username, password, locale string, logger io.Writer) error {
database.initDatabase = func(binaryExtractLocation, dataLocation, username, password, locale string, logger io.Writer) error {
return nil
}

Expand Down Expand Up @@ -345,3 +345,82 @@ func Test_CanStartAndStopTwice(t *testing.T) {
shutdownDBAndFail(t, err, database)
}
}

//nolint:funlen
func Test_ReuseData(t *testing.T) {
tempDir, err := ioutil.TempDir("", "embedded_postgres_test")
if err != nil {
panic(err)
}

defer func() {
if err := os.RemoveAll(tempDir); err != nil {
panic(err)
}
}()

database := NewDatabase(DefaultConfig().DataPath(tempDir))

if err := database.Start(); err != nil {
shutdownDBAndFail(t, err, database)
}

db, err := sql.Open("postgres", fmt.Sprintf("host=localhost port=5432 user=postgres password=postgres dbname=postgres sslmode=disable"))
if err != nil {
shutdownDBAndFail(t, err, database)
}

if _, err = db.Exec("CREATE TABLE test(id serial, value text, PRIMARY KEY(id))"); err != nil {
shutdownDBAndFail(t, err, database)
}

if _, err = db.Exec("INSERT INTO test (value) VALUES ('foobar')"); err != nil {
shutdownDBAndFail(t, err, database)
}

if err := db.Close(); err != nil {
shutdownDBAndFail(t, err, database)
}

if err := database.Stop(); err != nil {
shutdownDBAndFail(t, err, database)
}

database = NewDatabase(DefaultConfig().DataPath(tempDir))

if err := database.Start(); err != nil {
shutdownDBAndFail(t, err, database)
}

db, err = sql.Open("postgres", fmt.Sprintf("host=localhost port=5432 user=postgres password=postgres dbname=postgres sslmode=disable"))
if err != nil {
shutdownDBAndFail(t, err, database)
}

if rows, err := db.Query("SELECT * FROM test"); err != nil {
shutdownDBAndFail(t, err, database)
} else {
if !rows.Next() {
shutdownDBAndFail(t, errors.New("no row from db"), database)
}

var (
id int64
value string
)
if err := rows.Scan(&id, &value); err != nil {
shutdownDBAndFail(t, err, database)
}
if value != "foobar" {
shutdownDBAndFail(t, errors.New("wrong value from db"), database)
}
}

if err := db.Close(); err != nil {
shutdownDBAndFail(t, err, database)
}

if err := database.Stop(); err != nil {
shutdownDBAndFail(t, err, database)
}
}
27 changes: 25 additions & 2 deletions platform-test/platform_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@ import (
"os"
"path/filepath"
"strconv"
"strings"
"testing"

"github.com/fergusstrange/embedded-postgres"
embeddedpostgres "github.com/fergusstrange/embedded-postgres"
)

func Test_AllMajorVersions(t *testing.T) {
Expand All @@ -28,10 +29,11 @@ func Test_AllMajorVersions(t *testing.T) {
for testNumber, version := range allVersions {
t.Run(fmt.Sprintf("MajorVersion_%d", testNumber), func(t *testing.T) {
port := uint32(5555 + testNumber)
runtimePath := filepath.Join(tempExtractLocation, strconv.Itoa(testNumber))
database := embeddedpostgres.NewDatabase(embeddedpostgres.DefaultConfig().
Version(version).
Port(port).
RuntimePath(filepath.Join(tempExtractLocation, strconv.Itoa(testNumber))))
RuntimePath(runtimePath))

if err := database.Start(); err != nil {
shutdownDBAndFail(t, err, database, version)
Expand All @@ -57,6 +59,10 @@ func Test_AllMajorVersions(t *testing.T) {
if err := database.Stop(); err != nil {
t.Fatal(err)
}

if err := checkPgVersionFile(filepath.Join(runtimePath, "data"), version); err != nil {
t.Fatal(err)
}
})
}
if err := os.RemoveAll(tempExtractLocation); err != nil {
Expand All @@ -75,3 +81,20 @@ func connect(port uint32) (*sql.DB, error) {
db, err := sql.Open("postgres", fmt.Sprintf("host=localhost port=%d user=postgres password=postgres dbname=postgres sslmode=disable", port))
return db, err
}

func checkPgVersionFile(dataDir string, version embeddedpostgres.PostgresVersion) error {
pgVersion := filepath.Join(dataDir, "PG_VERSION")

d, err := ioutil.ReadFile(pgVersion)
if err != nil {
return fmt.Errorf("could not read file %v", pgVersion)
}

v := strings.TrimSuffix(string(d), "\n")

if strings.HasPrefix(string(version), v) {
return nil
}

return fmt.Errorf("version missmatch in PG_VERSION: %v <> %v", string(version), v)
}
6 changes: 3 additions & 3 deletions prepare_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ import (
"github.com/lib/pq"
)

type initDatabase func(binaryExtractLocation, username, password, locale string, logger io.Writer) error
type initDatabase func(binaryExtractLocation, pgDataDir, username, password, locale string, logger io.Writer) error
type createDatabase func(port uint32, username, password, database string) error

func defaultInitDatabase(binaryExtractLocation, username, password, locale string, logger io.Writer) error {
func defaultInitDatabase(binaryExtractLocation, pgDataDir, username, password, locale string, logger io.Writer) error {
passwordFile, err := createPasswordFile(binaryExtractLocation, password)
if err != nil {
return err
Expand All @@ -26,7 +26,7 @@ func defaultInitDatabase(binaryExtractLocation, username, password, locale strin
args := []string{
"-A", "password",
"-U", username,
"-D", filepath.Join(binaryExtractLocation, "data"),
"-D", pgDataDir,
fmt.Sprintf("--pwfile=%s", passwordFile),
}

Expand Down
Loading

0 comments on commit e846ac4

Please sign in to comment.