diff --git a/postgresql/resource_postgresql_script.go b/postgresql/resource_postgresql_script.go index af3348ee..1d10cfca 100644 --- a/postgresql/resource_postgresql_script.go +++ b/postgresql/resource_postgresql_script.go @@ -14,6 +14,7 @@ import ( const ( scriptCommandsAttr = "commands" + scriptDatabaseAttr = "database" scriptTriesAttr = "tries" scriptBackoffDelayAttr = "backoff_delay" scriptTimeoutAttr = "timeout" @@ -28,6 +29,13 @@ func resourcePostgreSQLScript() *schema.Resource { Delete: PGResourceFunc(resourcePostgreSQLScriptDelete), Schema: map[string]*schema.Schema{ + scriptDatabaseAttr: { + Type: schema.TypeString, + Optional: true, + Computed: true, + ForceNew: true, + Description: "The database to execute commands in (defaults to provider's configured database)", + }, scriptCommandsAttr: { Type: schema.TypeList, Required: true, @@ -77,9 +85,20 @@ func resourcePostgreSQLScriptCreateOrUpdate(ctx context.Context, db *DBConnectio }} } - sum := shasumCommands(commands) + // Get the target database connection + database := getDatabaseAttrOrDefault(d, db.client.databaseName) + + client := db.client.config.NewClient(database) + newDB, err := client.Connect() + if err != nil { + return diag.Diagnostics{diag.Diagnostic{ + Severity: diag.Error, + Summary: "Failed to connect to database", + Detail: err.Error(), + }} + } - if err := executeCommands(ctx, db, commands, tries, backoffDelay, timeout); err != nil { + if err := executeCommands(ctx, newDB, commands, tries, backoffDelay, timeout); err != nil { return diag.Diagnostics{diag.Diagnostic{ Severity: diag.Error, Summary: "Commands execution failed", @@ -87,18 +106,43 @@ func resourcePostgreSQLScriptCreateOrUpdate(ctx context.Context, db *DBConnectio }} } - d.Set(scriptShasumAttr, sum) + sum := shasumCommands(commands) d.SetId(sum) + + if err := resourcePostgreSQLScriptReadImpl(db, d); err != nil { + return diag.Diagnostics{diag.Diagnostic{ + Severity: diag.Error, + Summary: "Failed to read script state", + Detail: err.Error(), + }} + } + return nil } +func getDatabaseAttrOrDefault(d *schema.ResourceData, databaseName string) string { + if v, ok := d.GetOk(scriptDatabaseAttr); ok { + databaseName = v.(string) + } + + return databaseName +} + func resourcePostgreSQLScriptRead(db *DBConnection, d *schema.ResourceData) error { + return resourcePostgreSQLScriptReadImpl(db, d) +} + +func resourcePostgreSQLScriptReadImpl(db *DBConnection, d *schema.ResourceData) error { commands, err := toStringArray(d.Get(scriptCommandsAttr).([]any)) if err != nil { return err } newSum := shasumCommands(commands) + + database := getDatabaseAttrOrDefault(d, db.client.databaseName) + d.Set(scriptShasumAttr, newSum) + d.Set(scriptDatabaseAttr, database) return nil } diff --git a/postgresql/resource_postgresql_script_test.go b/postgresql/resource_postgresql_script_test.go index 8e624237..4edf1753 100644 --- a/postgresql/resource_postgresql_script_test.go +++ b/postgresql/resource_postgresql_script_test.go @@ -1,10 +1,12 @@ package postgresql import ( + "fmt" "regexp" "testing" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource" + "github.com/hashicorp/terraform-plugin-sdk/v2/terraform" ) func TestAccPostgresqlScript_basic(t *testing.T) { @@ -227,3 +229,126 @@ func TestAccPostgresqlScript_timeout(t *testing.T) { }, }) } + +func TestAccPostgresqlScript_withDatabase(t *testing.T) { + config := ` + resource "postgresql_database" "test_db" { + name = "test_script_db" + } + + resource "postgresql_script" "test" { + database = postgresql_database.test_db.name + commands = [ + "CREATE TABLE test_table (id INT);", + "INSERT INTO test_table VALUES (1);" + ] + depends_on = [postgresql_database.test_db] + } + + resource "postgresql_script" "test_default" { + commands = [ + "CREATE TABLE default_db_table (id INT);", + "INSERT INTO default_db_table VALUES (1);", + "INSERT INTO default_db_table VALUES (2);" + ] + depends_on = [postgresql_database.test_db] + } + ` + + resource.Test(t, resource.TestCase{ + PreCheck: func() { testAccPreCheck(t) }, + Providers: testAccProviders, + CheckDestroy: testAccCheckScriptTablesDestroyed, + Steps: []resource.TestStep{ + { + Config: config, + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr("postgresql_script.test", "database", "test_script_db"), + resource.TestCheckResourceAttr("postgresql_script.test", "commands.0", "CREATE TABLE test_table (id INT);"), + resource.TestCheckResourceAttr("postgresql_script.test", "commands.1", "INSERT INTO test_table VALUES (1);"), + resource.TestCheckResourceAttr("postgresql_script.test_default", "database", "postgres"), + resource.TestCheckResourceAttr("postgresql_script.test_default", "commands.0", "CREATE TABLE default_db_table (id INT);"), + resource.TestCheckResourceAttr("postgresql_script.test_default", "commands.1", "INSERT INTO default_db_table VALUES (1);"), + resource.TestCheckResourceAttr("postgresql_script.test_default", "commands.2", "INSERT INTO default_db_table VALUES (2);"), + testAccCheckTableExistsInDatabase("test_script_db", "test_table"), + testAccCheckTableHasRecords("test_script_db", "test_table", 1), + testAccCheckTableExistsInDatabase("postgres", "default_db_table"), + testAccCheckTableHasRecords("postgres", "default_db_table", 2), + ), + }, + }, + }) +} + +func testAccCheckScriptTablesDestroyed(s *terraform.State) error { + return testAccDropTables(map[string][]string{ + "test_script_db": {"test_table"}, + "postgres": {"default_db_table"}, + }) +} + +func testAccDropTables(tablesToDrop map[string][]string) error { + client := testAccProvider.Meta().(*Client) + + for dbName, tables := range tablesToDrop { + dbClient := client.config.NewClient(dbName) + db, err := dbClient.Connect() + if err != nil { + continue // Skip if we can't connect to the database + } + + for _, tableName := range tables { + _, _ = db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName)) + } + } + + return nil +} + +func testAccCheckTableExistsInDatabase(dbName, tableName string) resource.TestCheckFunc { + return func(s *terraform.State) error { + client := testAccProvider.Meta().(*Client) + dbClient := client.config.NewClient(dbName) + db, err := dbClient.Connect() + if err != nil { + return fmt.Errorf("Error connecting to database %s: %s", dbName, err) + } + + var exists bool + query := "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_schema = 'public' AND table_name = $1)" + err = db.QueryRow(query, tableName).Scan(&exists) + if err != nil { + return fmt.Errorf("Error checking if table %s exists: %s", tableName, err) + } + + if !exists { + return fmt.Errorf("Table %s does not exist in database %s", tableName, dbName) + } + + return nil + } +} + +func testAccCheckTableHasRecords(dbName, tableName string, expectedCount int) resource.TestCheckFunc { + return func(s *terraform.State) error { + client := testAccProvider.Meta().(*Client) + dbClient := client.config.NewClient(dbName) + db, err := dbClient.Connect() + if err != nil { + return fmt.Errorf("Error connecting to database %s: %s", dbName, err) + } + + var count int + query := fmt.Sprintf("SELECT COUNT(*) FROM %s", tableName) + err = db.QueryRow(query).Scan(&count) + if err != nil { + return fmt.Errorf("Error counting records in table %s: %s", tableName, err) + } + + if count != expectedCount { + return fmt.Errorf("Expected %d records but got %d in table %s", expectedCount, count, tableName) + } + + return nil + } +}