Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend Update/Delete queries with returning statement #610

Merged
merged 8 commits into from
Dec 19, 2022
2 changes: 1 addition & 1 deletion acra-censor/common/matching_logic.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ func handleDeleteStatement(query, pattern sqlparser.Statement) bool {
if !match {
return false
}
match = areEqualTableNames(queryDeleteNode.Targets, patternDeleteNode.Targets)
match = areEqualTableExprs(queryDeleteNode.Targets, patternDeleteNode.Targets)
if !match {
return false
}
Expand Down
111 changes: 90 additions & 21 deletions encryptor/queryDataEncryptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ func (encryptor *QueryDataEncryptor) encryptInsertQuery(ctx context.Context, ins
}

if encryptor.encryptor == nil {
return false, encryptor.onReturning(ctx, insert.Returning, tableName.ValueForConfig())
return false, encryptor.onReturning(ctx, insert.Returning, []sqlparser.TableExpr{&sqlparser.AliasedTableExpr{
Expr: insert.Table,
}})
}

var columnsName []string
Expand Down Expand Up @@ -319,15 +321,30 @@ func NewAliasToTableMapFromTables(tables []*AliasedTableName) AliasToTableMap {

// encryptUpdateQuery encrypt data in Update query and return true if any fields was encrypted, false if wasn't and error if error occurred
func (encryptor *QueryDataEncryptor) encryptUpdateQuery(ctx context.Context, update *sqlparser.Update, bindPlaceholders map[int]config.ColumnEncryptionSetting) (bool, error) {
tables := GetTablesWithAliases(update.TableExprs)
if !encryptor.hasTablesToEncrypt(tables) {
if len(update.TableExprs) == 0 {
return false, nil
}
if len(tables) == 0 {

fromTables := update.TableExprs

if len(update.From) != 0 {
fromTables = append(fromTables, update.From...)
}

tables := GetTablesWithAliases(fromTables)
if !encryptor.hasTablesToEncrypt(tables) {
return false, nil
}

qualifierMap := NewAliasToTableMapFromTables(tables)
firstTable := tables[0].TableName

// MySQL/MariaDB don`t support returning after update statements
// Postgres doest but expect only one table in tables expression, but also can have more tables in FROM statement
if encryptor.encryptor == nil {
return false, encryptor.onReturning(ctx, update.Returning, fromTables)
}

return encryptor.encryptUpdateExpressions(ctx, update.Exprs, firstTable, qualifierMap, bindPlaceholders)
}

Expand Down Expand Up @@ -398,26 +415,70 @@ func (encryptor *QueryDataEncryptor) onSelect(ctx context.Context, statement *sq
return false, nil
}

func (encryptor *QueryDataEncryptor) onReturning(ctx context.Context, returning sqlparser.Returning, tableName string) error {
func (encryptor *QueryDataEncryptor) onDelete(ctx context.Context, delete *sqlparser.Delete) (bool, error) {
if len(delete.TableExprs) == 0 {
return false, nil
}

fromTables := delete.TableExprs

if len(delete.Targets) != 0 {
fromTables = append(fromTables, delete.Targets...)
}

tables := GetTablesWithAliases(fromTables)
if !encryptor.hasTablesToEncrypt(tables) {
return false, nil
}

if encryptor.encryptor == nil {
return false, encryptor.onReturning(ctx, delete.Returning, fromTables)
}

return false, nil
}

func (encryptor *QueryDataEncryptor) onReturning(ctx context.Context, returning sqlparser.Returning, fromTables sqlparser.TableExprs) error {
if len(returning) == 0 {
return nil
}

schema := encryptor.schemaStore.GetTableSchema(tableName)
querySelectSettings := make([]*QueryDataItem, 0, 8)

if _, ok := returning[0].(*sqlparser.StarExpr); ok {
for _, name := range schema.Columns() {
if columnSetting := schema.GetColumnEncryptionSettings(name); columnSetting != nil {
querySelectSettings = append(querySelectSettings, &QueryDataItem{
setting: columnSetting,
tableName: tableName,
columnName: name,
})
for _, tableExp := range fromTables {
aliased, ok := tableExp.(*sqlparser.AliasedTableExpr)
if !ok {
continue
}
querySelectSettings = append(querySelectSettings, nil)

tableName, ok := aliased.Expr.(sqlparser.TableName)
if !ok {
continue
}

// if the Returning is star and we have more than one table in the query e.g.
// update table1 set did = tt.did from table2 as tt returning *
// and the table is not in the encryptor config we cant collect corresponding querySettings as we dont actual table representation
tableSchema := encryptor.schemaStore.GetTableSchema(tableName.Name.ValueForConfig())
if tableSchema == nil {
logrus.WithField("table", tableName.Name.ValueForConfig()).Info("Unable to collect querySettings for table not in encryptor config")
return errors.New("error to collect settings for unknown table")
}

for _, name := range tableSchema.Columns() {
if columnSetting := tableSchema.GetColumnEncryptionSettings(name); columnSetting != nil {
querySelectSettings = append(querySelectSettings, &QueryDataItem{
setting: columnSetting,
tableName: tableName.Name.ValueForConfig(),
columnName: name,
})
continue
}
querySelectSettings = append(querySelectSettings, nil)
}
}

clientSession := base.ClientSessionFromContext(ctx)
SaveQueryDataItemsToClientSession(clientSession, querySelectSettings)
encryptor.querySelectSettings = querySelectSettings
Expand All @@ -442,12 +503,20 @@ func (encryptor *QueryDataEncryptor) onReturning(ctx context.Context, returning
querySelectSettings = append(querySelectSettings, nil)
continue
}
rawColName := colName.Name.String()
if columnSetting := schema.GetColumnEncryptionSettings(rawColName); columnSetting != nil {

columnInfo, err := findColumnInfo(fromTables, colName, encryptor.schemaStore)
if err != nil {
querySelectSettings = append(querySelectSettings, nil)
continue
}

tableSchema := encryptor.schemaStore.GetTableSchema(columnInfo.Table)

if columnSetting := tableSchema.GetColumnEncryptionSettings(columnInfo.Name); columnSetting != nil {
querySelectSettings = append(querySelectSettings, &QueryDataItem{
setting: columnSetting,
tableName: tableName,
columnName: rawColName,
tableName: columnInfo.Table,
columnName: columnInfo.Name,
})
continue
}
Expand Down Expand Up @@ -476,9 +545,9 @@ func (encryptor *QueryDataEncryptor) OnQuery(ctx context.Context, query base.OnQ
case *sqlparser.Insert:
changed, err = encryptor.encryptInsertQuery(ctx, typedStatement, bindPlaceholders)
case *sqlparser.Update:
if encryptor.encryptor != nil {
changed, err = encryptor.encryptUpdateQuery(ctx, typedStatement, bindPlaceholders)
}
changed, err = encryptor.encryptUpdateQuery(ctx, typedStatement, bindPlaceholders)
case *sqlparser.Delete:
changed, err = encryptor.onDelete(ctx, typedStatement)
}
if err != nil {
return query, false, err
Expand Down
199 changes: 177 additions & 22 deletions encryptor/queryDataEncryptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,16 @@ schemas:
client_id: %s
- column: zone_id
zone_id: %s
`, clientIDStr, zoneIDStr)

- table: tablewithcolumnschema_2
columns: ["other_column_2", "default_client_id_2", "specified_client_id_2", "zone_id_2"]
encrypted:
- column: "default_client_id_2"
- column: specified_client_id_2
client_id: %s
- column: zone_id_2
zone_id: %s
`, clientIDStr, zoneIDStr, clientIDStr, zoneIDStr)
schemaStore, err := config.MapTableSchemaStoreFromConfig([]byte(configStr), config.UseMySQL)
if err != nil {
t.Fatalf("Can't parse config: %s", err.Error())
Expand Down Expand Up @@ -898,38 +907,184 @@ schemas:
})

t.Run("RETURNING columns with sql literals", func(t *testing.T) {
sqlparser.SetDefaultDialect(postgresql.NewPostgreSQLDialect())

returning := "1, 0 as literal, zone_id, specified_client_id, other_column, default_client_id, NULL"
query := fmt.Sprintf(`INSERT INTO TableWithColumnSchema
('zone_id', 'specified_client_id', 'other_column', 'default_client_id') VALUES (1, 1, 1, 1) RETURNING %s`, returning)
queryTemplates := []string{
"INSERT INTO TableWithColumnSchema ('zone_id', 'specified_client_id', 'other_column', 'default_client_id') VALUES (1, 1, 1, 1) RETURNING %s",
"UPDATE TableWithColumnSchema SET price = price * 1.10 WHERE price <= 99.99 RETURNING %s",
"DELETE FROM TableWithColumnSchema WHERE price <= 99.99 RETURNING %s",
}

_, _, err := encryptor.OnQuery(ctx, base.NewOnQueryObjectFromQuery(query, parser))
if err != nil {
t.Fatalf("%s", err.Error())
for _, template := range queryTemplates {
query := fmt.Sprintf(template, returning)

_, _, err := encryptor.OnQuery(ctx, base.NewOnQueryObjectFromQuery(query, parser))
if err != nil {
t.Fatalf("%s", err.Error())
}

returningColumns := strings.Split(returning, ", ")
// 1, 0 as literal, NULL
extraValuesCount := 3
if (len(columns) + extraValuesCount) != len(returningColumns) {
t.Fatalf("Incorrect encryptor.querySelectSettings length")
}

expectedNilColumns := map[int]struct{}{
0: {},
1: {},
4: {},
6: {},
}

for i := range returningColumns {
if _, ok := expectedNilColumns[i]; ok {
continue
}

setting := encryptor.querySelectSettings[i]

if returningColumns[i] != setting.columnName {
t.Fatalf("%v. Incorrect QueryDataItem \nTook: %v\nExpected: %v", i, setting.columnName, columns[i])
}
}
}
})

returningColumns := strings.Split(returning, ", ")
// 1, 0 as literal, NULL
extraValuesCount := 3
if (len(columns) + extraValuesCount) != len(returningColumns) {
t.Fatalf("Incorrect encryptor.querySelectSettings length")
t.Run("RETURNING columns with sql literals and several tables from config", func(t *testing.T) {
sqlparser.SetDefaultDialect(postgresql.NewPostgreSQLDialect())

returning := "specified_client_id, specified_client_id_2, default_client_id, default_client_id_2"
queryTemplates := []string{
"UPDATE TableWithColumnSchema SET specified_client_id = t2.specified_client_id FROM TableWithColumnSchema_2 as t2 RETURNING %s",
"DELETE FROM TableWithColumnSchema USING TableWithColumnSchema_2 WHERE specified_client_id_2 = specified_client_id RETURNING %s",
}

expectedNilColumns := map[int]struct{}{
0: {},
1: {},
4: {},
6: {},
for _, template := range queryTemplates {
query := fmt.Sprintf(template, returning)

_, _, err := encryptor.OnQuery(ctx, base.NewOnQueryObjectFromQuery(query, parser))
if err != nil {
t.Fatalf("%s", err.Error())
}

returningColumns := strings.Split(returning, ", ")
if len(encryptor.querySelectSettings) != len(returningColumns) {
t.Fatalf("Incorrect encryptor.querySelectSettings length")
}

expectedTables := map[int]string{
0: "tablewithcolumnschema",
1: "tablewithcolumnschema_2",
2: "tablewithcolumnschema",
3: "tablewithcolumnschema_2",
}

for i := range returningColumns {
setting := encryptor.querySelectSettings[i]
if setting == nil {
t.Fatalf("expected setting not to be nil")
}

if setting.tableName != expectedTables[i] {
t.Fatalf("Unexpected setting.tableName, expected %s but got %s", expectedTables[i], setting.tableName)
}
}
}
})

for i := range returningColumns {
if _, ok := expectedNilColumns[i]; ok {
continue
t.Run("RETURNING columns with sql literals and several tables", func(t *testing.T) {
sqlparser.SetDefaultDialect(postgresql.NewPostgreSQLDialect())

returning := "specified_client_id_2, specified_unknown_column, default_client_id_2, default_unknown_column"
Zhaars marked this conversation as resolved.
Show resolved Hide resolved
queryTemplates := []string{
"UPDATE UnknownTable SET specified_client_id = t2.specified_client_id FROM TableWithColumnSchema_2 as t2 RETURNING %s",
"UPDATE TableWithColumnSchema_2 as t2 SET specified_client_id = t2.specified_client_id FROM UnknownTable RETURNING %s",
"DELETE FROM UnknownTable USING TableWithColumnSchema_2 as t2 WHERE t2.specified_client_id = default_unknown_column RETURNING %s",
"DELETE FROM TableWithColumnSchema_2 USING UnknownTable WHERE specified_client_id = default_unknown_column RETURNING %s",
}

for _, template := range queryTemplates {
query := fmt.Sprintf(template, returning)

_, _, err := encryptor.OnQuery(ctx, base.NewOnQueryObjectFromQuery(query, parser))
if err != nil {
t.Fatalf("%s", err.Error())
}

setting := encryptor.querySelectSettings[i]
returningColumns := strings.Split(returning, ", ")
if len(encryptor.querySelectSettings) != len(returningColumns) {
t.Fatalf("Incorrect encryptor.querySelectSettings length")
}

if returningColumns[i] != setting.columnName {
t.Fatalf("%v. Incorrect QueryDataItem \nTook: %v\nExpected: %v", i, setting.columnName, columns[i])
tableFromConfig := "tablewithcolumnschema_2"
expectedTables := map[int]*string{
0: &tableFromConfig,
1: nil,
2: &tableFromConfig,
3: nil,
}

for i := range returningColumns {
setting := encryptor.querySelectSettings[i]

if expectedTables[i] == nil && setting != nil {
t.Fatalf("Expected setting to be nil, but got %s", setting)
}

if expectedTables[i] != nil && setting == nil {
t.Fatalf("Expected setting not to be nil, but got nil")
}

if table := expectedTables[i]; table != nil {
if *table != setting.tableName {
t.Fatalf("Unexpected setting table name, want %s but got %s", *table, setting.tableName)
}
}
}
}
})

t.Run("RETURNING with star and several tables", func(t *testing.T) {
sqlparser.SetDefaultDialect(postgresql.NewPostgreSQLDialect())

returning := "*"
queryTemplates := []string{
"UPDATE TableWithColumnSchema SET specified_client_id = t2.specified_client_id FROM TableWithColumnSchema_2 as t2 RETURNING %s",
"DELETE FROM TableWithColumnSchema USING TableWithColumnSchema_2 as t2 WHERE did = t2.did RETURNING %s",
}

tableSchema := schemaStore.GetTableSchema("tablewithcolumnschema")
table2Schema := schemaStore.GetTableSchema("tablewithcolumnschema_2")

expectSettingNumber := len(tableSchema.Columns()) + len(table2Schema.Columns())

for _, template := range queryTemplates {
query := fmt.Sprintf(template, returning)

_, _, err := encryptor.OnQuery(ctx, base.NewOnQueryObjectFromQuery(query, parser))
if err != nil {
t.Fatalf("%s", err.Error())
}

expectedNilColumns := map[int]struct{}{
0: {},
4: {},
}

if expectSettingNumber != len(encryptor.querySelectSettings) {
t.Fatalf("Incorrect number of encryptor.querySelectSettings")
}

for i := 0; i < expectSettingNumber; i++ {
if _, ok := expectedNilColumns[i]; ok {
setting := encryptor.querySelectSettings[i]

if setting != nil {
t.Fatalf("Expected nil setting, but got not %s", setting.columnName)
}
}
}
}
})
Expand Down