Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend Update/Delete queries with returning statement #610

Merged
merged 8 commits into from
Dec 19, 2022
30 changes: 27 additions & 3 deletions encryptor/queryDataEncryptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,8 +326,16 @@ func (encryptor *QueryDataEncryptor) encryptUpdateQuery(ctx context.Context, upd
if len(tables) == 0 {
return false, nil
}

qualifierMap := NewAliasToTableMapFromTables(tables)
firstTable := tables[0].TableName

// MySQL/MariaDB dont support returning after update statements
// Postgres doest but expect only one table in tables expression, so we can take the firstTable for returning matching
Zhaars marked this conversation as resolved.
Show resolved Hide resolved
if encryptor.encryptor == nil {
return false, encryptor.onReturning(ctx, update.Returning, firstTable.Name.ValueForConfig())
}

return encryptor.encryptUpdateExpressions(ctx, update.Exprs, firstTable, qualifierMap, bindPlaceholders)
}

Expand Down Expand Up @@ -398,6 +406,22 @@ func (encryptor *QueryDataEncryptor) onSelect(ctx context.Context, statement *sq
return false, nil
}

func (encryptor *QueryDataEncryptor) onDelete(ctx context.Context, delete *sqlparser.Delete) (bool, error) {
tables := GetTablesWithAliases(delete.TableExprs)
if !encryptor.hasTablesToEncrypt(tables) {
return false, nil
}
if len(tables) == 0 {
return false, nil
}

if encryptor.encryptor == nil {
return false, encryptor.onReturning(ctx, delete.Returning, tables[0].TableName.Name.ValueForConfig())
Zhaars marked this conversation as resolved.
Show resolved Hide resolved
}

return false, nil
}

func (encryptor *QueryDataEncryptor) onReturning(ctx context.Context, returning sqlparser.Returning, tableName string) error {
if len(returning) == 0 {
return nil
Expand Down Expand Up @@ -476,9 +500,9 @@ func (encryptor *QueryDataEncryptor) OnQuery(ctx context.Context, query base.OnQ
case *sqlparser.Insert:
changed, err = encryptor.encryptInsertQuery(ctx, typedStatement, bindPlaceholders)
case *sqlparser.Update:
if encryptor.encryptor != nil {
changed, err = encryptor.encryptUpdateQuery(ctx, typedStatement, bindPlaceholders)
}
changed, err = encryptor.encryptUpdateQuery(ctx, typedStatement, bindPlaceholders)
case *sqlparser.Delete:
changed, err = encryptor.onDelete(ctx, typedStatement)
}
if err != nil {
return query, false, err
Expand Down
57 changes: 33 additions & 24 deletions encryptor/queryDataEncryptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -898,38 +898,47 @@ schemas:
})

t.Run("RETURNING columns with sql literals", func(t *testing.T) {
returning := "1, 0 as literal, zone_id, specified_client_id, other_column, default_client_id, NULL"
query := fmt.Sprintf(`INSERT INTO TableWithColumnSchema
('zone_id', 'specified_client_id', 'other_column', 'default_client_id') VALUES (1, 1, 1, 1) RETURNING %s`, returning)
sqlparser.SetDefaultDialect(postgresql.NewPostgreSQLDialect())

_, _, err := encryptor.OnQuery(ctx, base.NewOnQueryObjectFromQuery(query, parser))
if err != nil {
t.Fatalf("%s", err.Error())
returning := "1, 0 as literal, zone_id, specified_client_id, other_column, default_client_id, NULL"
queryTemplates := []string{
"INSERT INTO TableWithColumnSchema ('zone_id', 'specified_client_id', 'other_column', 'default_client_id') VALUES (1, 1, 1, 1) RETURNING %s",
"UPDATE TableWithColumnSchema SET price = price * 1.10 WHERE price <= 99.99 RETURNING %s",
"DELETE FROM TableWithColumnSchema WHERE price <= 99.99 RETURNING %s",
}

returningColumns := strings.Split(returning, ", ")
// 1, 0 as literal, NULL
extraValuesCount := 3
if (len(columns) + extraValuesCount) != len(returningColumns) {
t.Fatalf("Incorrect encryptor.querySelectSettings length")
}
for _, template := range queryTemplates {
query := fmt.Sprintf(template, returning)

expectedNilColumns := map[int]struct{}{
0: {},
1: {},
4: {},
6: {},
}
_, _, err := encryptor.OnQuery(ctx, base.NewOnQueryObjectFromQuery(query, parser))
if err != nil {
t.Fatalf("%s", err.Error())
}

for i := range returningColumns {
if _, ok := expectedNilColumns[i]; ok {
continue
returningColumns := strings.Split(returning, ", ")
// 1, 0 as literal, NULL
extraValuesCount := 3
if (len(columns) + extraValuesCount) != len(returningColumns) {
t.Fatalf("Incorrect encryptor.querySelectSettings length")
}

setting := encryptor.querySelectSettings[i]
expectedNilColumns := map[int]struct{}{
0: {},
1: {},
4: {},
6: {},
}

if returningColumns[i] != setting.columnName {
t.Fatalf("%v. Incorrect QueryDataItem \nTook: %v\nExpected: %v", i, setting.columnName, columns[i])
for i := range returningColumns {
if _, ok := expectedNilColumns[i]; ok {
continue
}

setting := encryptor.querySelectSettings[i]

if returningColumns[i] != setting.columnName {
t.Fatalf("%v. Incorrect QueryDataItem \nTook: %v\nExpected: %v", i, setting.columnName, columns[i])
}
}
}
})
Expand Down
2 changes: 2 additions & 0 deletions sqlparser/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,7 @@ type Update struct {
Where *Where
OrderBy OrderBy
Limit *Limit
Returning Returning
}

// Delete represents a DELETE statement.
Expand All @@ -367,6 +368,7 @@ type Delete struct {
Where *Where
OrderBy OrderBy
Limit *Limit
Returning Returning
}

// Set represents a SET statement.
Expand Down
6 changes: 3 additions & 3 deletions sqlparser/ast_methods.go
Original file line number Diff line number Diff line change
Expand Up @@ -309,9 +309,9 @@ func (node *Insert) walkSubtree(visit Visit) error {

// Format formats the node.
func (node *Update) Format(buf *TrackedBuffer) {
buf.Myprintf("update %v%v set %v%v%v%v",
buf.Myprintf("update %v%v set %v%v%v%v%v",
node.Comments, node.TableExprs,
node.Exprs, node.Where, node.OrderBy, node.Limit)
node.Exprs, node.Where, node.OrderBy, node.Limit, node.Returning)
}

func (node *Update) walkSubtree(visit Visit) error {
Expand All @@ -335,7 +335,7 @@ func (node *Delete) Format(buf *TrackedBuffer) {
if node.Targets != nil {
buf.Myprintf("%v ", node.Targets)
}
buf.Myprintf("from %v%v%v%v%v", node.TableExprs, node.Partitions, node.Where, node.OrderBy, node.Limit)
buf.Myprintf("from %v%v%v%v%v%v", node.TableExprs, node.Partitions, node.Where, node.OrderBy, node.Limit, node.Returning)
}

func (node *Delete) walkSubtree(visit Visit) error {
Expand Down
12 changes: 12 additions & 0 deletions sqlparser/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1377,6 +1377,14 @@ var (
input: "SELECT * FROM dual WHERE val NOT ILIKE 'test%'",
output: "select * from dual where val not ilike 'test%'",
dialect: postgresql.NewPostgreSQLDialect(),
}, {
input: "delete from dual where price <= 99.99 returning 1, 0 as literal, zone_id, specified_client_id, other_column, default_client_id, null",
dialect: postgresql.NewPostgreSQLDialect(),
}, {
input: "delete from dual where price <= 99.99 returning 1, 0 as literal, zone_id, specified_client_id, other_column, default_client_id, null",
}, {
input: "update dual set price = price * 1.10 where price <= 99.99 returning 1, 0 as literal, zone_id, specified_client_id, other_column, default_client_id, null",
dialect: postgresql.NewPostgreSQLDialect(),
},
}
)
Expand Down Expand Up @@ -2252,6 +2260,10 @@ var (
input: "SELECT * FROM dual WHERE val NOT ILIKE 'test%'",
output: "MySQL dialect doesn't support `ILIKE` statement at position 47",
dialect: mysql.NewMySQLDialect(),
}, {
input: "UPDATE dual SET price = price * 1.1 RETURNING 1, 0 as literal",
output: "MySQL/MariaDB dialect doesn't support returning with update statement at position 62",
dialect: mysql.NewMySQLDialect(),
},
}
)
Expand Down