Skip to content

Commit

Permalink
Extend Update/Delete queries with returning statement (#610)
Browse files Browse the repository at this point in the history
Extend Update/Delete queries with Returning
  • Loading branch information
Zhaars committed Dec 19, 2022
1 parent 0b47fca commit 0517312
Show file tree
Hide file tree
Showing 9 changed files with 2,592 additions and 2,118 deletions.
2 changes: 1 addition & 1 deletion acra-censor/common/matching_logic.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ func handleDeleteStatement(query, pattern sqlparser.Statement) bool {
if !match {
return false
}
match = areEqualTableNames(queryDeleteNode.Targets, patternDeleteNode.Targets)
match = areEqualTableExprs(queryDeleteNode.Targets, patternDeleteNode.Targets)
if !match {
return false
}
Expand Down
111 changes: 90 additions & 21 deletions encryptor/queryDataEncryptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ func (encryptor *QueryDataEncryptor) encryptInsertQuery(ctx context.Context, ins
}

if encryptor.encryptor == nil {
return false, encryptor.onReturning(ctx, insert.Returning, tableName.ValueForConfig())
return false, encryptor.onReturning(ctx, insert.Returning, []sqlparser.TableExpr{&sqlparser.AliasedTableExpr{
Expr: insert.Table,
}})
}

var columnsName []string
Expand Down Expand Up @@ -319,15 +321,30 @@ func NewAliasToTableMapFromTables(tables []*AliasedTableName) AliasToTableMap {

// encryptUpdateQuery encrypt data in Update query and return true if any fields was encrypted, false if wasn't and error if error occurred
func (encryptor *QueryDataEncryptor) encryptUpdateQuery(ctx context.Context, update *sqlparser.Update, bindPlaceholders map[int]config.ColumnEncryptionSetting) (bool, error) {
tables := GetTablesWithAliases(update.TableExprs)
if !encryptor.hasTablesToEncrypt(tables) {
if len(update.TableExprs) == 0 {
return false, nil
}
if len(tables) == 0 {

fromTables := update.TableExprs

if len(update.From) != 0 {
fromTables = append(fromTables, update.From...)
}

tables := GetTablesWithAliases(fromTables)
if !encryptor.hasTablesToEncrypt(tables) {
return false, nil
}

qualifierMap := NewAliasToTableMapFromTables(tables)
firstTable := tables[0].TableName

// MySQL/MariaDB don`t support returning after update statements
// Postgres doest but expect only one table in tables expression, but also can have more tables in FROM statement
if encryptor.encryptor == nil {
return false, encryptor.onReturning(ctx, update.Returning, fromTables)
}

return encryptor.encryptUpdateExpressions(ctx, update.Exprs, firstTable, qualifierMap, bindPlaceholders)
}

Expand Down Expand Up @@ -398,26 +415,70 @@ func (encryptor *QueryDataEncryptor) onSelect(ctx context.Context, statement *sq
return false, nil
}

func (encryptor *QueryDataEncryptor) onReturning(ctx context.Context, returning sqlparser.Returning, tableName string) error {
func (encryptor *QueryDataEncryptor) onDelete(ctx context.Context, delete *sqlparser.Delete) (bool, error) {
if len(delete.TableExprs) == 0 {
return false, nil
}

fromTables := delete.TableExprs

if len(delete.Targets) != 0 {
fromTables = append(fromTables, delete.Targets...)
}

tables := GetTablesWithAliases(fromTables)
if !encryptor.hasTablesToEncrypt(tables) {
return false, nil
}

if encryptor.encryptor == nil {
return false, encryptor.onReturning(ctx, delete.Returning, fromTables)
}

return false, nil
}

func (encryptor *QueryDataEncryptor) onReturning(ctx context.Context, returning sqlparser.Returning, fromTables sqlparser.TableExprs) error {
if len(returning) == 0 {
return nil
}

schema := encryptor.schemaStore.GetTableSchema(tableName)
querySelectSettings := make([]*QueryDataItem, 0, 8)

if _, ok := returning[0].(*sqlparser.StarExpr); ok {
for _, name := range schema.Columns() {
if columnSetting := schema.GetColumnEncryptionSettings(name); columnSetting != nil {
querySelectSettings = append(querySelectSettings, &QueryDataItem{
setting: columnSetting,
tableName: tableName,
columnName: name,
})
for _, tableExp := range fromTables {
aliased, ok := tableExp.(*sqlparser.AliasedTableExpr)
if !ok {
continue
}
querySelectSettings = append(querySelectSettings, nil)

tableName, ok := aliased.Expr.(sqlparser.TableName)
if !ok {
continue
}

// if the Returning is star and we have more than one table in the query e.g.
// update table1 set did = tt.did from table2 as tt returning *
// and the table is not in the encryptor config we cant collect corresponding querySettings as we dont actual table representation
tableSchema := encryptor.schemaStore.GetTableSchema(tableName.Name.ValueForConfig())
if tableSchema == nil {
logrus.WithField("table", tableName.Name.ValueForConfig()).Info("Unable to collect querySettings for table not in encryptor config")
return errors.New("error to collect settings for unknown table")
}

for _, name := range tableSchema.Columns() {
if columnSetting := tableSchema.GetColumnEncryptionSettings(name); columnSetting != nil {
querySelectSettings = append(querySelectSettings, &QueryDataItem{
setting: columnSetting,
tableName: tableName.Name.ValueForConfig(),
columnName: name,
})
continue
}
querySelectSettings = append(querySelectSettings, nil)
}
}

clientSession := base.ClientSessionFromContext(ctx)
SaveQueryDataItemsToClientSession(clientSession, querySelectSettings)
encryptor.querySelectSettings = querySelectSettings
Expand All @@ -442,12 +503,20 @@ func (encryptor *QueryDataEncryptor) onReturning(ctx context.Context, returning
querySelectSettings = append(querySelectSettings, nil)
continue
}
rawColName := colName.Name.String()
if columnSetting := schema.GetColumnEncryptionSettings(rawColName); columnSetting != nil {

columnInfo, err := findColumnInfo(fromTables, colName, encryptor.schemaStore)
if err != nil {
querySelectSettings = append(querySelectSettings, nil)
continue
}

tableSchema := encryptor.schemaStore.GetTableSchema(columnInfo.Table)

if columnSetting := tableSchema.GetColumnEncryptionSettings(columnInfo.Name); columnSetting != nil {
querySelectSettings = append(querySelectSettings, &QueryDataItem{
setting: columnSetting,
tableName: tableName,
columnName: rawColName,
tableName: columnInfo.Table,
columnName: columnInfo.Name,
})
continue
}
Expand Down Expand Up @@ -476,9 +545,9 @@ func (encryptor *QueryDataEncryptor) OnQuery(ctx context.Context, query base.OnQ
case *sqlparser.Insert:
changed, err = encryptor.encryptInsertQuery(ctx, typedStatement, bindPlaceholders)
case *sqlparser.Update:
if encryptor.encryptor != nil {
changed, err = encryptor.encryptUpdateQuery(ctx, typedStatement, bindPlaceholders)
}
changed, err = encryptor.encryptUpdateQuery(ctx, typedStatement, bindPlaceholders)
case *sqlparser.Delete:
changed, err = encryptor.onDelete(ctx, typedStatement)
}
if err != nil {
return query, false, err
Expand Down

0 comments on commit 0517312

Please sign in to comment.