Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 50 additions & 1 deletion postgres_pgx/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ type TestConfig struct {
RunContainer bool
ContainerImage string
ContainerName string
MacConnections int
}

func NewTestPlugin(ctx context.Context, cfg TestConfig) (Plugin, error) {
Expand Down Expand Up @@ -135,6 +136,11 @@ func NewTestPlugin(ctx context.Context, cfg TestConfig) (Plugin, error) {
WithOccurrence(2).
WithStartupTimeout(5 * time.Second)

cmd := "-N 500"
if cfg.MacConnections > 0 {
cmd = fmt.Sprintf("-N %d", cfg.MacConnections)
}

container, err := postgres.RunContainer(ctx,
testcontainers.WithImage(opts.ContainerImage),
postgres.WithDatabase("postgres"),
Expand All @@ -144,6 +150,7 @@ func NewTestPlugin(ctx context.Context, cfg TestConfig) (Plugin, error) {
testcontainers.CustomizeRequest(testcontainers.GenericContainerRequest{
ContainerRequest: testcontainers.ContainerRequest{
Name: opts.ContainerName,
Cmd: []string{cmd},
},
Reuse: true,
}),
Expand All @@ -161,7 +168,39 @@ func NewTestPlugin(ctx context.Context, cfg TestConfig) (Plugin, error) {
return nil, err
}

opts.DSN = fmt.Sprintf("postgres://user:pass@%v:%v/postgres?sslmode=disable", host, realPort.Port())
dbURL := fmt.Sprintf("postgres://user:pass@%v:%v/postgres?sslmode=disable", host, realPort.Port())

conn, err := pgx.Connect(ctx, dbURL)
if err != nil {
log.Fatalf("failed to connect to database: %v", err)
}

if opts.Database == "" {
opts.Database = "postgres"
}

exists, err := databaseExists(ctx, conn, opts.Database)
if err != nil {
return nil, fmt.Errorf("failed to check if database exists: %v", err)
}

if !exists {
createDBQuery := fmt.Sprintf("CREATE DATABASE %s", opts.Database)
_, err = conn.Exec(ctx, createDBQuery)
if err != nil {
return nil, fmt.Errorf("failed to create database: %v", err)
}
fmt.Printf("Database %s created successfully\n", opts.Database)
} else {
fmt.Printf("Database %s already exists, continuing...\n", opts.Database)
}
_ = conn.Close(ctx)

if opts.Database != "postgres" {
opts.DSN = fmt.Sprintf("postgres://user:pass@%v:%v/%s?sslmode=disable", host, realPort.Port(), opts.Database)
} else {
opts.DSN = dbURL
}
}

p := new(plugin)
Expand Down Expand Up @@ -304,3 +343,13 @@ func (p *plugin) initPlugin(ctx context.Context) error {

return nil
}

func databaseExists(ctx context.Context, conn *pgx.Conn, dbName string) (bool, error) {
var exists bool
query := fmt.Sprintf("SELECT EXISTS(SELECT datname FROM pg_catalog.pg_database WHERE datname = '%s')", dbName)
err := conn.QueryRow(ctx, query).Scan(&exists)
if err != nil {
return false, err
}
return exists, nil
}