diff --git a/testing/platform/pgtesting/postgres.go b/testing/platform/pgtesting/postgres.go index f8c19ca4..01c4746a 100644 --- a/testing/platform/pgtesting/postgres.go +++ b/testing/platform/pgtesting/postgres.go @@ -20,17 +20,20 @@ import ( "github.com/stretchr/testify/require" ) -type TestingT interface { +type T interface { require.TestingT Cleanup(func()) } type Database struct { - Url string + url string + t T + dbName string + rootUrl string } func (s *Database) ConnString() string { - return s.Url + return s.url } func (s *Database) ConnectionOptions() bunconnect.ConnectionOptions { @@ -39,6 +42,21 @@ func (s *Database) ConnectionOptions() bunconnect.ConnectionOptions { } } +func (s *Database) Delete() { + db, err := sql.Open("postgres", s.rootUrl) + require.NoError(s.t, err) + defer func() { + require.NoError(s.t, db.Close()) + }() + + _, err = db.ExecContext(sharedlogging.TestingContext(), fmt.Sprintf(`drop database if exists "%s"`, s.dbName)) + require.NoError(s.t, err) +} + +func (s *Database) Name() string { + return s.dbName +} + type PostgresServer struct { Port string Config Config @@ -73,7 +91,20 @@ func (s *PostgresServer) GetDatabaseDSN(databaseName string) string { s.Config.InitialUserPassword, s.GetHost(), s.Port, databaseName) } -func (s *PostgresServer) NewDatabase(t TestingT) *Database { +func (s *PostgresServer) setupDatabase(t T, name string) { + db, err := sql.Open("postgres", s.GetDatabaseDSN(name)) + require.NoError(t, err) + defer func() { + require.NoError(t, db.Close()) + }() + + for _, extension := range s.Config.Extensions { + _, err = db.ExecContext(sharedlogging.TestingContext(), fmt.Sprintf(`create extension "%s" schema public`, extension)) + require.NoError(t, err) + } +} + +func (s *PostgresServer) NewDatabase(t T) *Database { db, err := sql.Open("postgres", s.GetDSN()) require.NoError(t, err) defer func() { @@ -81,27 +112,23 @@ func (s *PostgresServer) NewDatabase(t TestingT) *Database { }() databaseName := uuid.NewString() - _, err = db.ExecContext(sharedlogging.TestingContext(), fmt.Sprintf(`CREATE DATABASE "%s"`, databaseName)) + _, err = db.ExecContext(sharedlogging.TestingContext(), fmt.Sprintf(`create database "%s"`, databaseName)) require.NoError(t, err) - if os.Getenv("NO_CLEANUP") != "true" { - t.Cleanup(func() { - db, err := sql.Open("postgres", s.GetDSN()) - require.NoError(t, err) - defer func() { - require.Nil(t, db.Close()) - }() + s.setupDatabase(t, databaseName) - _, err = db.ExecContext(sharedlogging.TestingContext(), fmt.Sprintf(`DROP DATABASE "%s"`, databaseName)) - if err != nil { - panic(err) - } - }) + ret := &Database{ + rootUrl: s.GetDSN(), + url: s.GetDatabaseDSN(databaseName), + t: t, + dbName: databaseName, } - return &Database{ - Url: s.GetDatabaseDSN(databaseName), + if os.Getenv("NO_CLEANUP") != "true" { + t.Cleanup(ret.Delete) } + + return ret } type Config struct { @@ -110,6 +137,7 @@ type Config struct { InitialUsername string StatusCheckInterval time.Duration MaximumWaitingTime time.Duration + Extensions []string } func (c Config) validate() error { @@ -128,41 +156,55 @@ func (c Config) validate() error { return nil } -type option func(opts *Config) +type Option func(opts *Config) -func WithInitialDatabaseName(name string) option { +func WithInitialDatabaseName(name string) Option { return func(opts *Config) { opts.InitialDatabaseName = name } } -func WithInitialUser(username, pwd string) option { +func WithInitialUser(username, pwd string) Option { return func(opts *Config) { opts.InitialUserPassword = pwd opts.InitialUsername = username } } -func WithStatusCheckInterval(d time.Duration) option { +func WithStatusCheckInterval(d time.Duration) Option { return func(opts *Config) { opts.StatusCheckInterval = d } } -func WithMaximumWaitingTime(d time.Duration) option { +func WithMaximumWaitingTime(d time.Duration) Option { return func(opts *Config) { opts.MaximumWaitingTime = d } } -var defaultOptions = []option{ +func WithExtension(extensions ...string) Option { + return func(opts *Config) { + opts.Extensions = append(opts.Extensions, extensions...) + } +} + +func WithPGStatsExtension() Option { + return WithExtension("pg_stat_statements") +} + +func WithPGCrypto() Option { + return WithExtension("pgcrypto") +} + +var defaultOptions = []Option{ WithStatusCheckInterval(200 * time.Millisecond), WithInitialUser("root", "root"), WithMaximumWaitingTime(time.Minute), WithInitialDatabaseName("formance"), } -func CreatePostgresServer(t TestingT, pool *docker.Pool, opts ...option) *PostgresServer { +func CreatePostgresServer(t T, pool *docker.Pool, opts ...Option) *PostgresServer { cfg := Config{} for _, opt := range append(defaultOptions, opts...) { opt(&cfg) @@ -184,6 +226,9 @@ func CreatePostgresServer(t TestingT, pool *docker.Pool, opts ...option) *Postgr "-c", "enable_partition_pruning=on", "-c", "enable_partitionwise_join=on", "-c", "enable_partitionwise_aggregate=on", + "-c", "shared_preload_libraries=auto_explain,pg_stat_statements", + "-c", "log_lock_waits=on", + "-c", "log_min_messages=info", }, }, CheckFn: func(ctx context.Context, resource *dockertest.Resource) error {