Skip to content
50 changes: 47 additions & 3 deletions postgresql/resource_postgresql_script.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (

const (
scriptCommandsAttr = "commands"
scriptDatabaseAttr = "database"
scriptTriesAttr = "tries"
scriptBackoffDelayAttr = "backoff_delay"
scriptTimeoutAttr = "timeout"
Expand All @@ -28,6 +29,13 @@ func resourcePostgreSQLScript() *schema.Resource {
Delete: PGResourceFunc(resourcePostgreSQLScriptDelete),

Schema: map[string]*schema.Schema{
scriptDatabaseAttr: {
Type: schema.TypeString,
Optional: true,
Computed: true,
ForceNew: true,
Description: "The database to execute commands in (defaults to provider's configured database)",
},
scriptCommandsAttr: {
Type: schema.TypeList,
Required: true,
Expand Down Expand Up @@ -77,28 +85,64 @@ func resourcePostgreSQLScriptCreateOrUpdate(ctx context.Context, db *DBConnectio
}}
}

sum := shasumCommands(commands)
// Get the target database connection
database := getDatabaseAttrOrDefault(d, db.client.databaseName)

client := db.client.config.NewClient(database)
newDB, err := client.Connect()
if err != nil {
return diag.Diagnostics{diag.Diagnostic{
Severity: diag.Error,
Summary: "Failed to connect to database",
Detail: err.Error(),
}}
}

if err := executeCommands(ctx, db, commands, tries, backoffDelay, timeout); err != nil {
if err := executeCommands(ctx, newDB, commands, tries, backoffDelay, timeout); err != nil {
return diag.Diagnostics{diag.Diagnostic{
Severity: diag.Error,
Summary: "Commands execution failed",
Detail: err.Error(),
}}
}

d.Set(scriptShasumAttr, sum)
sum := shasumCommands(commands)
d.SetId(sum)

if err := resourcePostgreSQLScriptReadImpl(db, d); err != nil {
return diag.Diagnostics{diag.Diagnostic{
Severity: diag.Error,
Summary: "Failed to read script state",
Detail: err.Error(),
}}
}

return nil
}

func getDatabaseAttrOrDefault(d *schema.ResourceData, databaseName string) string {
if v, ok := d.GetOk(scriptDatabaseAttr); ok {
databaseName = v.(string)
}

return databaseName
}

func resourcePostgreSQLScriptRead(db *DBConnection, d *schema.ResourceData) error {
return resourcePostgreSQLScriptReadImpl(db, d)
}

func resourcePostgreSQLScriptReadImpl(db *DBConnection, d *schema.ResourceData) error {
commands, err := toStringArray(d.Get(scriptCommandsAttr).([]any))
if err != nil {
return err
}
newSum := shasumCommands(commands)

database := getDatabaseAttrOrDefault(d, db.client.databaseName)

d.Set(scriptShasumAttr, newSum)
d.Set(scriptDatabaseAttr, database)

return nil
}
Expand Down
125 changes: 125 additions & 0 deletions postgresql/resource_postgresql_script_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package postgresql

import (
"fmt"
"regexp"
"testing"

"github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource"
"github.com/hashicorp/terraform-plugin-sdk/v2/terraform"
)

func TestAccPostgresqlScript_basic(t *testing.T) {
Expand Down Expand Up @@ -227,3 +229,126 @@ func TestAccPostgresqlScript_timeout(t *testing.T) {
},
})
}

func TestAccPostgresqlScript_withDatabase(t *testing.T) {
config := `
resource "postgresql_database" "test_db" {
name = "test_script_db"
}

resource "postgresql_script" "test" {
database = postgresql_database.test_db.name
commands = [
"CREATE TABLE test_table (id INT);",
"INSERT INTO test_table VALUES (1);"
]
depends_on = [postgresql_database.test_db]
}

resource "postgresql_script" "test_default" {
commands = [
"CREATE TABLE default_db_table (id INT);",
"INSERT INTO default_db_table VALUES (1);",
"INSERT INTO default_db_table VALUES (2);"
]
depends_on = [postgresql_database.test_db]
}
`

resource.Test(t, resource.TestCase{
PreCheck: func() { testAccPreCheck(t) },
Providers: testAccProviders,
CheckDestroy: testAccCheckScriptTablesDestroyed,
Steps: []resource.TestStep{
{
Config: config,
Check: resource.ComposeTestCheckFunc(
resource.TestCheckResourceAttr("postgresql_script.test", "database", "test_script_db"),
resource.TestCheckResourceAttr("postgresql_script.test", "commands.0", "CREATE TABLE test_table (id INT);"),
resource.TestCheckResourceAttr("postgresql_script.test", "commands.1", "INSERT INTO test_table VALUES (1);"),
resource.TestCheckResourceAttr("postgresql_script.test_default", "database", "postgres"),
resource.TestCheckResourceAttr("postgresql_script.test_default", "commands.0", "CREATE TABLE default_db_table (id INT);"),
resource.TestCheckResourceAttr("postgresql_script.test_default", "commands.1", "INSERT INTO default_db_table VALUES (1);"),
resource.TestCheckResourceAttr("postgresql_script.test_default", "commands.2", "INSERT INTO default_db_table VALUES (2);"),
testAccCheckTableExistsInDatabase("test_script_db", "test_table"),
testAccCheckTableHasRecords("test_script_db", "test_table", 1),
testAccCheckTableExistsInDatabase("postgres", "default_db_table"),
testAccCheckTableHasRecords("postgres", "default_db_table", 2),
),
},
},
})
}

func testAccCheckScriptTablesDestroyed(s *terraform.State) error {
return testAccDropTables(map[string][]string{
"test_script_db": {"test_table"},
"postgres": {"default_db_table"},
})
}

func testAccDropTables(tablesToDrop map[string][]string) error {
client := testAccProvider.Meta().(*Client)

for dbName, tables := range tablesToDrop {
dbClient := client.config.NewClient(dbName)
db, err := dbClient.Connect()
if err != nil {
continue // Skip if we can't connect to the database
}

for _, tableName := range tables {
_, _ = db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName))
}
}

return nil
}

func testAccCheckTableExistsInDatabase(dbName, tableName string) resource.TestCheckFunc {
return func(s *terraform.State) error {
client := testAccProvider.Meta().(*Client)
dbClient := client.config.NewClient(dbName)
db, err := dbClient.Connect()
if err != nil {
return fmt.Errorf("Error connecting to database %s: %s", dbName, err)
}

var exists bool
query := "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_schema = 'public' AND table_name = $1)"
err = db.QueryRow(query, tableName).Scan(&exists)
if err != nil {
return fmt.Errorf("Error checking if table %s exists: %s", tableName, err)
}

if !exists {
return fmt.Errorf("Table %s does not exist in database %s", tableName, dbName)
}

return nil
}
}

func testAccCheckTableHasRecords(dbName, tableName string, expectedCount int) resource.TestCheckFunc {
return func(s *terraform.State) error {
client := testAccProvider.Meta().(*Client)
dbClient := client.config.NewClient(dbName)
db, err := dbClient.Connect()
if err != nil {
return fmt.Errorf("Error connecting to database %s: %s", dbName, err)
}

var count int
query := fmt.Sprintf("SELECT COUNT(*) FROM %s", tableName)
err = db.QueryRow(query).Scan(&count)
if err != nil {
return fmt.Errorf("Error counting records in table %s: %s", tableName, err)
}

if count != expectedCount {
return fmt.Errorf("Expected %d records but got %d in table %s", expectedCount, count, tableName)
}

return nil
}
}