Skip to content

Commit

Permalink
Adding support for schema management in snowflake
Browse files Browse the repository at this point in the history
  • Loading branch information
abhinavcohesity committed Jun 3, 2020
1 parent 7236e82 commit b00a0cc
Show file tree
Hide file tree
Showing 5 changed files with 406 additions and 3 deletions.
10 changes: 10 additions & 0 deletions database/snowflake/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Snowflake

`snowflake://user:password@accountname/schema/dbname?query`

| URL Query | WithInstance Config | Description |
|------------|---------------------|-------------|
| `x-migrations-table` | `MigrationsTable` | Name of the migrations table |

Snowflake is PostgreSQL compatible but has some specific features (or lack thereof) that require slightly different behavior.
Snowflake doesn't run locally hence there are no tests. The library works against hosted instances of snowflake.
371 changes: 371 additions & 0 deletions database/snowflake/snowflake.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,371 @@
package snowflake

import (
"context"
"database/sql"
"fmt"
"io"
"io/ioutil"
nurl "net/url"
"strconv"
"strings"

"github.com/golang-migrate/migrate/v4/database"
"github.com/hashicorp/go-multierror"
"github.com/lib/pq"
sf "github.com/snowflakedb/gosnowflake"
)

func init() {
db := Snowflake{}
database.Register("snowflake", &db)
}

var DefaultMigrationsTable = "schema_migrations"

var (
ErrNilConfig = fmt.Errorf("no config")
ErrNoDatabaseName = fmt.Errorf("no database name")
ErrNoPassword = fmt.Errorf("no password")
ErrNoSchema = fmt.Errorf("no schema")
ErrNoSchemaOrDatabase = fmt.Errorf("no schema/database name")
)

type Config struct {
MigrationsTable string
DatabaseName string
}

type Snowflake struct {
isLocked bool
conn *sql.Conn
db *sql.DB

// Open and WithInstance need to guarantee that config is never nil
config *Config
}

func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
if config == nil {
return nil, ErrNilConfig
}

if err := instance.Ping(); err != nil {
return nil, err
}

if config.DatabaseName == "" {
query := `SELECT CURRENT_DATABASE()`
var databaseName string
if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
}

if len(databaseName) == 0 {
return nil, ErrNoDatabaseName
}

config.DatabaseName = databaseName
}

if len(config.MigrationsTable) == 0 {
config.MigrationsTable = DefaultMigrationsTable
}

conn, err := instance.Conn(context.Background())

if err != nil {
return nil, err
}

px := &Snowflake{
conn: conn,
db: instance,
config: config,
}

if err := px.ensureVersionTable(); err != nil {
return nil, err
}

return px, nil
}

func (p *Snowflake) Open(url string) (database.Driver, error) {
purl, err := nurl.Parse(url)
if err != nil {
return nil, err
}

password, isPasswordSet := purl.User.Password()
if !isPasswordSet {
return nil, ErrNoPassword
}

splitPath := strings.Split(purl.Path, "/")
if len(splitPath) < 3 {
return nil, ErrNoSchemaOrDatabase
}

database := splitPath[2]
if len(database) == 0 {
return nil, ErrNoDatabaseName
}

schema := splitPath[1]
if len(schema) == 0 {
return nil, ErrNoSchema
}

cfg := &sf.Config{
Account: purl.Host,
User: purl.User.Username(),
Password: password,
Database: database,
Schema: schema,
}

dsn, err := sf.DSN(cfg)
if err != nil {
return nil, err
}

db, err := sql.Open("snowflake", dsn)
if err != nil {
return nil, err
}

migrationsTable := purl.Query().Get("x-migrations-table")

px, err := WithInstance(db, &Config{
DatabaseName: database,
MigrationsTable: migrationsTable,
})
if err != nil {
return nil, err
}

return px, nil
}

func (p *Snowflake) Close() error {
connErr := p.conn.Close()
dbErr := p.db.Close()
if connErr != nil || dbErr != nil {
return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
}
return nil
}

func (p *Snowflake) Lock() error {
if p.isLocked {
return database.ErrLocked
}
p.isLocked = true
return nil
}

func (p *Snowflake) Unlock() error {
p.isLocked = false
return nil
}

func (p *Snowflake) Run(migration io.Reader) error {
migr, err := ioutil.ReadAll(migration)
if err != nil {
return err
}

// run migration
query := string(migr[:])
if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
if pgErr, ok := err.(*pq.Error); ok {
var line uint
var col uint
var lineColOK bool
if pgErr.Position != "" {
if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil {
line, col, lineColOK = computeLineFromPos(query, int(pos))
}
}
message := fmt.Sprintf("migration failed: %s", pgErr.Message)
if lineColOK {
message = fmt.Sprintf("%s (column %d)", message, col)
}
if pgErr.Detail != "" {
message = fmt.Sprintf("%s, %s", message, pgErr.Detail)
}
return database.Error{OrigErr: err, Err: message, Query: migr, Line: line}
}
return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
}

return nil
}

func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) {
// replace crlf with lf
s = strings.Replace(s, "\r\n", "\n", -1)
// pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes
runes := []rune(s)
if pos > len(runes) {
return 0, 0, false
}
sel := runes[:pos]
line = uint(runesCount(sel, newLine) + 1)
col = uint(pos - 1 - runesLastIndex(sel, newLine))
return line, col, true
}

const newLine = '\n'

func runesCount(input []rune, target rune) int {
var count int
for _, r := range input {
if r == target {
count++
}
}
return count
}

func runesLastIndex(input []rune, target rune) int {
for i := len(input) - 1; i >= 0; i-- {
if input[i] == target {
return i
}
}
return -1
}

func (p *Snowflake) SetVersion(version int, dirty bool) error {
tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
if err != nil {
return &database.Error{OrigErr: err, Err: "transaction start failed"}
}

query := `DELETE FROM "` + p.config.MigrationsTable + `"`
if _, err := tx.Exec(query); err != nil {
if errRollback := tx.Rollback(); errRollback != nil {
err = multierror.Append(err, errRollback)
}
return &database.Error{OrigErr: err, Query: []byte(query)}
}

// Also re-write the schema version for nil dirty versions to prevent
// empty schema version for failed down migration on the first migration
// See: https://github.com/golang-migrate/migrate/issues/330
if version >= 0 || (version == database.NilVersion && dirty) {
query = `INSERT INTO "` + p.config.MigrationsTable + `" (version,
dirty) VALUES (` + strconv.FormatInt(int64(version), 10) + `,
` + strconv.FormatBool(dirty) + `)`
if _, err := tx.Exec(query); err != nil {
if errRollback := tx.Rollback(); errRollback != nil {
err = multierror.Append(err, errRollback)
}
return &database.Error{OrigErr: err, Query: []byte(query)}
}
}

if err := tx.Commit(); err != nil {
return &database.Error{OrigErr: err, Err: "transaction commit failed"}
}

return nil
}

func (p *Snowflake) Version() (version int, dirty bool, err error) {
query := `SELECT version, dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1`
err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
switch {
case err == sql.ErrNoRows:
return database.NilVersion, false, nil

case err != nil:
if e, ok := err.(*pq.Error); ok {
if e.Code.Name() == "undefined_table" {
return database.NilVersion, false, nil
}
}
return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}

default:
return version, dirty, nil
}
}

func (p *Snowflake) Drop() (err error) {
// select all tables in current schema
query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'`
tables, err := p.conn.QueryContext(context.Background(), query)
if err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
defer func() {
if errClose := tables.Close(); errClose != nil {
err = multierror.Append(err, errClose)
}
}()

// delete one table after another
tableNames := make([]string, 0)
for tables.Next() {
var tableName string
if err := tables.Scan(&tableName); err != nil {
return err
}
if len(tableName) > 0 {
tableNames = append(tableNames, tableName)
}
}

if len(tableNames) > 0 {
// delete one by one ...
for _, t := range tableNames {
query = `DROP TABLE IF EXISTS ` + t + ` CASCADE`
if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
}
}

return nil
}

// ensureVersionTable checks if versions table exists and, if not, creates it.
// Note that this function locks the database, which deviates from the usual
// convention of "caller locks" in the Snowflake type.
func (p *Snowflake) ensureVersionTable() (err error) {
if err = p.Lock(); err != nil {
return err
}

defer func() {
if e := p.Unlock(); e != nil {
if err == nil {
err = e
} else {
err = multierror.Append(err, e)
}
}
}()

// check if migration table exists
var count int
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
if err := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable).Scan(&count); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
if count == 1 {
return nil
}

// if not, create the empty migration table
query = `CREATE TABLE if not exists "` + p.config.MigrationsTable + `" (
version bigint not null primary key, dirty boolean not null)`
if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}

return nil
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,14 @@ require (
github.com/mattn/go-sqlite3 v1.10.0
github.com/nakagami/firebirdsql v0.0.0-20190310045651-3c02a58cfed8
github.com/neo4j/neo4j-go-driver v1.8.0-beta02
github.com/snowflakedb/gosnowflake v1.3.5
github.com/stretchr/testify v1.5.1
github.com/tidwall/pretty v0.0.0-20180105212114-65a9db5fad51 // indirect
github.com/xanzy/go-gitlab v0.15.0
github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c // indirect
github.com/xdg/stringprep v1.0.0 // indirect
gitlab.com/nyarla/go-crypt v0.0.0-20160106005555-d9a5dc2b789b // indirect
go.mongodb.org/mongo-driver v1.1.0
golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073 // indirect
golang.org/x/exp v0.0.0-20200213203834-85f925bdd4d0 // indirect
golang.org/x/net v0.0.0-20200202094626-16171245cfb2
golang.org/x/tools v0.0.0-20200213224642-88e652f7a869
Expand Down
Loading

0 comments on commit b00a0cc

Please sign in to comment.