diff --git a/db.go b/db.go index 4e07588..a88ed5f 100644 --- a/db.go +++ b/db.go @@ -101,7 +101,11 @@ func (m *DbMap) createIndexImpl(ctx context.Context, dialect reflect.Type, s.WriteString(" unique") } s.WriteString(" index") - s.WriteString(fmt.Sprintf(" %s on %s", index.IndexName, table.TableName)) + s.WriteString(fmt.Sprintf( + " %s on %s", + m.Dialect.QuoteField(index.IndexName), + m.Dialect.QuotedTableForQuery(table.SchemaName, table.TableName), + )) if dname := dialect.Name(); dname == "PostgresDialect" && index.IndexType != "" { s.WriteString(fmt.Sprintf(" %s %s", m.Dialect.CreateIndexSuffix(), index.IndexType)) } @@ -129,10 +133,14 @@ func (t *TableMap) DropIndex(ctx context.Context, name string) error { for _, idx := range t.indexes { if idx.IndexName == name { s := bytes.Buffer{} - s.WriteString(fmt.Sprintf("DROP INDEX %s", idx.IndexName)) + s.WriteString(fmt.Sprintf("DROP INDEX %s", t.dbmap.Dialect.QuoteField(idx.IndexName))) if dname := dialect.Name(); dname == "MySQLDialect" { - s.WriteString(fmt.Sprintf(" %s %s", t.dbmap.Dialect.DropIndexSuffix(), t.TableName)) + s.WriteString(fmt.Sprintf( + " %s %s", + t.dbmap.Dialect.DropIndexSuffix(), + t.dbmap.Dialect.QuotedTableForQuery(t.SchemaName, t.TableName), + )) } s.WriteString(";") _, e := t.dbmap.ExecContext(ctx, s.String()) diff --git a/dialect_mysql.go b/dialect_mysql.go index 1dfc2be..adead48 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -146,7 +146,7 @@ func (d MySQLDialect) InsertAutoIncr(ctx context.Context, exec SqlExecutor, inse } func (d MySQLDialect) QuoteField(f string) string { - return "`" + f + "`" + return "`" + strings.ReplaceAll(f, "`", "``") + "`" } func (d MySQLDialect) QuotedTableForQuery(schema string, table string) string { @@ -154,7 +154,7 @@ func (d MySQLDialect) QuotedTableForQuery(schema string, table string) string { return d.QuoteField(table) } - return schema + "." + d.QuoteField(table) + return d.QuoteField(schema) + "." + d.QuoteField(table) } func (d MySQLDialect) IfSchemaNotExists(command, schema string) string { diff --git a/dialect_mysql_test.go b/dialect_mysql_test.go index 966162e..660ae85 100644 --- a/dialect_mysql_test.go +++ b/dialect_mysql_test.go @@ -141,6 +141,7 @@ func TestMySQLDialect(t *testing.T) { o.Spec("QuoteField", func(tcx testContext) { tcx.expect(tcx.dialect.QuoteField("foo")).To(matchers.Equal("`foo`")) + tcx.expect(tcx.dialect.QuoteField("fo`o")).To(matchers.Equal("`fo``o`")) }) o.Group("QuotedTableForQuery", func() { @@ -149,7 +150,8 @@ func TestMySQLDialect(t *testing.T) { }) o.Spec("with a supplied schema", func(tcx testContext) { - tcx.expect(tcx.dialect.QuotedTableForQuery("foo", "bar")).To(matchers.Equal("foo.`bar`")) + tcx.expect(tcx.dialect.QuotedTableForQuery("foo", "bar")).To(matchers.Equal("`foo`.`bar`")) + tcx.expect(tcx.dialect.QuotedTableForQuery("fo`o", "ba`r")).To(matchers.Equal("`fo``o`.`ba``r`")) }) }) diff --git a/dialect_postgres.go b/dialect_postgres.go index 937f81e..02113d1 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -124,9 +124,9 @@ func (d PostgresDialect) InsertAutoIncrToTarget(ctx context.Context, exec SqlExe func (d PostgresDialect) QuoteField(f string) string { if d.LowercaseFields { - return `"` + strings.ToLower(f) + `"` + f = strings.ToLower(f) } - return `"` + f + `"` + return `"` + strings.ReplaceAll(f, `"`, `""`) + `"` } func (d PostgresDialect) QuotedTableForQuery(schema string, table string) string { @@ -134,7 +134,7 @@ func (d PostgresDialect) QuotedTableForQuery(schema string, table string) string return d.QuoteField(table) } - return schema + "." + d.QuoteField(table) + return d.QuoteField(schema) + "." + d.QuoteField(table) } func (d PostgresDialect) IfSchemaNotExists(command, schema string) string { diff --git a/dialect_postgres_test.go b/dialect_postgres_test.go index 4a2f674..25e4604 100644 --- a/dialect_postgres_test.go +++ b/dialect_postgres_test.go @@ -120,6 +120,7 @@ func TestPostgresDialect(t *testing.T) { o.Spec("By default, case is preserved", func(tcx postgresTestContext) { tcx.expect(tcx.dialect.QuoteField("Foo")).To(matchers.Equal(`"Foo"`)) tcx.expect(tcx.dialect.QuoteField("bar")).To(matchers.Equal(`"bar"`)) + tcx.expect(tcx.dialect.QuoteField(`Fo"o`)).To(matchers.Equal(`"Fo""o"`)) }) o.Group("With LowercaseFields set to true", func() { @@ -130,6 +131,7 @@ func TestPostgresDialect(t *testing.T) { o.Spec("fields are lowercased", func(tcx postgresTestContext) { tcx.expect(tcx.dialect.QuoteField("Foo")).To(matchers.Equal(`"foo"`)) + tcx.expect(tcx.dialect.QuoteField(`Fo"O`)).To(matchers.Equal(`"fo""o"`)) }) }) }) @@ -140,7 +142,8 @@ func TestPostgresDialect(t *testing.T) { }) o.Spec("with a supplied schema", func(tcx postgresTestContext) { - tcx.expect(tcx.dialect.QuotedTableForQuery("foo", "bar")).To(matchers.Equal(`foo."bar"`)) + tcx.expect(tcx.dialect.QuotedTableForQuery("foo", "bar")).To(matchers.Equal(`"foo"."bar"`)) + tcx.expect(tcx.dialect.QuotedTableForQuery(`fo"o`, `ba"r`)).To(matchers.Equal(`"fo""o"."ba""r"`)) }) }) diff --git a/dialect_sqlite.go b/dialect_sqlite.go index a0ba40d..d422174 100644 --- a/dialect_sqlite.go +++ b/dialect_sqlite.go @@ -8,6 +8,7 @@ import ( "context" "fmt" "reflect" + "strings" ) type SqliteDialect struct { @@ -92,7 +93,7 @@ func (d SqliteDialect) InsertAutoIncr(ctx context.Context, exec SqlExecutor, ins } func (d SqliteDialect) QuoteField(f string) string { - return `"` + f + `"` + return `"` + strings.ReplaceAll(f, `"`, `""`) + `"` } // sqlite does not have schemas like PostgreSQL does, so just escape it like normal diff --git a/identifier_quote_test.go b/identifier_quote_test.go new file mode 100644 index 0000000..f250f85 --- /dev/null +++ b/identifier_quote_test.go @@ -0,0 +1,196 @@ +package borp + +import ( + "context" + "database/sql" + "database/sql/driver" + "errors" + "sync" + "testing" + + _ "github.com/mattn/go-sqlite3" +) + +type identifierCapturedExec struct { + query string + args []driver.NamedValue +} + +var identifierCaptureState = struct { + sync.Mutex + registered sync.Once + execs []identifierCapturedExec +}{} + +type identifierCaptureDriver struct{} + +func (identifierCaptureDriver) Open(string) (driver.Conn, error) { + return identifierCaptureConn{}, nil +} + +type identifierCaptureConn struct{} + +func (identifierCaptureConn) Prepare(string) (driver.Stmt, error) { + return nil, errors.New("identifier capture driver does not prepare statements") +} + +func (identifierCaptureConn) Close() error { + return nil +} + +func (identifierCaptureConn) Begin() (driver.Tx, error) { + return nil, errors.New("identifier capture driver does not begin transactions") +} + +func (identifierCaptureConn) ExecContext( + _ context.Context, + query string, + args []driver.NamedValue, +) (driver.Result, error) { + argsCopy := append([]driver.NamedValue(nil), args...) + + identifierCaptureState.Lock() + defer identifierCaptureState.Unlock() + identifierCaptureState.execs = append(identifierCaptureState.execs, identifierCapturedExec{ + query: query, + args: argsCopy, + }) + return driver.RowsAffected(0), nil +} + +func newIdentifierCaptureDbMap(t *testing.T, dialect Dialect) *DbMap { + t.Helper() + + identifierCaptureState.registered.Do(func() { + sql.Register("borp_identifier_capture", identifierCaptureDriver{}) + }) + + identifierCaptureState.Lock() + identifierCaptureState.execs = nil + identifierCaptureState.Unlock() + + db, err := sql.Open("borp_identifier_capture", "") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + closeErr := db.Close() + if closeErr != nil { + t.Fatal(closeErr) + } + }) + + return &DbMap{Db: db, Dialect: dialect} +} + +func identifierCapturedExecs() []identifierCapturedExec { + identifierCaptureState.Lock() + defer identifierCaptureState.Unlock() + return append([]identifierCapturedExec(nil), identifierCaptureState.execs...) +} + +func TestSqliteDialectEscapesIdentifierQuotes(t *testing.T) { + dialect := SqliteDialect{} + got := dialect.QuoteField(`fo"o`) + want := `"fo""o"` + if got != want { + t.Fatalf("QuoteField() = %q, want %q", got, want) + } + got = dialect.QuotedTableForQuery("", `ta"ble`) + want = `"ta""ble"` + if got != want { + t.Fatalf("QuotedTableForQuery() = %q, want %q", got, want) + } +} + +func TestQuotedTableNameCannotRewriteUpdateTarget(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + _, err = db.Exec("CREATE TABLE victim (id integer primary key, value text, admin integer)") + if err != nil { + t.Fatal(err) + } + _, err = db.Exec("INSERT INTO victim (id, value, admin) VALUES (1, 'unchanged', 0)") + if err != nil { + t.Fatal(err) + } + + type row struct { + ID int64 `db:"ID"` + Value string `db:"Value"` + } + + dbmap := &DbMap{Db: db, Dialect: SqliteDialect{}} + injectedTable := `victim" SET admin = 1 WHERE ? <> ? -- ` + dbmap.AddTableWithName(row{}, injectedTable).SetKeys(false, "ID") + + _, err = dbmap.Update(context.Background(), &row{ID: 1, Value: "unused"}) + if err == nil { + t.Fatal("Update succeeded for escaped malicious table name") + } + + var admin int + err = db.QueryRow("SELECT admin FROM victim WHERE id = 1").Scan(&admin) + if err != nil { + t.Fatal(err) + } + if admin != 0 { + t.Fatalf("victim.admin = %d, want 0", admin) + } +} + +func TestCreateIndexQuotesIdentifierMetadata(t *testing.T) { + type indexedRow struct { + ID int64 `db:"ID"` + } + + dbmap := newIdentifierCaptureDbMap(t, PostgresDialect{}) + table := dbmap.AddTableWithNameAndSchema(indexedRow{}, `sche"ma`, `security"rows`) + table.SetKeys(false, "ID") + table.AddIndex(`idx"name`, "btree", []string{"ID"}) + + err := dbmap.CreateIndex(context.Background()) + if err != nil { + t.Fatal(err) + } + + execs := identifierCapturedExecs() + if len(execs) != 1 { + t.Fatalf("expected one captured exec, got %d: %+v", len(execs), execs) + } + + want := `create index "idx""name" on "sche""ma"."security""rows" using btree ("ID");` + if execs[0].query != want { + t.Fatalf("generated %q, want %q", execs[0].query, want) + } +} + +func TestDropIndexQuotesIdentifierMetadata(t *testing.T) { + type indexedRow struct { + ID int64 `db:"ID"` + } + + dbmap := newIdentifierCaptureDbMap(t, MySQLDialect{Engine: "InnoDB", Encoding: "UTF8"}) + table := dbmap.AddTableWithNameAndSchema(indexedRow{}, "sche`ma", "security`rows") + table.SetKeys(false, "ID") + table.AddIndex("idx`name", "Btree", []string{"ID"}) + + err := table.DropIndex(context.Background(), "idx`name") + if err != nil { + t.Fatal(err) + } + + execs := identifierCapturedExecs() + if len(execs) != 1 { + t.Fatalf("expected one captured exec, got %d: %+v", len(execs), execs) + } + + want := "DROP INDEX `idx``name` on `sche``ma`.`security``rows`;" + if execs[0].query != want { + t.Fatalf("generated %q, want %q", execs[0].query, want) + } +}