diff --git a/database/snowflake/README.md b/database/snowflake/README.md index 90a28d177..358abed6a 100644 --- a/database/snowflake/README.md +++ b/database/snowflake/README.md @@ -2,11 +2,14 @@ `snowflake://user:password@accountname/schema/dbname?query` -| URL Query | WithInstance Config | Description | -|------------|---------------------|-------------| -| `x-migrations-table` | `MigrationsTable` | Name of the migrations table | +| URL Query | WithInstance Config | Description | +| -------------------- | ------------------- | ---------------------------- | +| `x-migrations-table` | `MigrationsTable` | Name of the migrations table | +| `warehouse` | | Snowflake warehouse to use | -Snowflake is PostgreSQL compatible but has some specific features (or lack thereof) that require slightly different behavior. +Snowflake is PostgreSQL compatible but has some specific features (or lack +thereof) that require slightly different behavior. ## Status + This driver is not officially supported as there are no tests for it. diff --git a/database/snowflake/snowflake.go b/database/snowflake/snowflake.go index 46ce30200..32714508f 100644 --- a/database/snowflake/snowflake.go +++ b/database/snowflake/snowflake.go @@ -118,12 +118,15 @@ func (p *Snowflake) Open(url string) (database.Driver, error) { return nil, ErrNoSchema } + warehouse := purl.Query().Get("warehouse") + cfg := &sf.Config{ - Account: purl.Host, - User: purl.User.Username(), - Password: password, - Database: database, - Schema: schema, + Account: purl.Host, + User: purl.User.Username(), + Password: password, + Database: database, + Schema: schema, + Warehouse: warehouse, } dsn, err := sf.DSN(cfg) @@ -180,7 +183,12 @@ func (p *Snowflake) Run(migration io.Reader) error { // run migration query := string(migr[:]) - if _, err := p.conn.ExecContext(context.Background(), query); err != nil { + stmtCount := countStatements(query) + context, err := sf.WithMultiStatement(context.Background(), stmtCount) + if err != nil { + return err + } + if _, err := p.conn.ExecContext(context, query); err != nil { if pgErr, ok := err.(*pq.Error); ok { var line uint var col uint @@ -205,6 +213,11 @@ func (p *Snowflake) Run(migration io.Reader) error { return nil } +func countStatements(query string) int { + semicolonCount := strings.Count(query, ";") + return semicolonCount +} + func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) { // replace crlf with lf s = strings.Replace(s, "\r\n", "\n", -1)