Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend Update/Delete queries with returning statement #610

Merged
merged 8 commits into from
Dec 19, 2022
2 changes: 1 addition & 1 deletion acra-censor/common/matching_logic.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ func handleDeleteStatement(query, pattern sqlparser.Statement) bool {
if !match {
return false
}
match = areEqualTableNames(queryDeleteNode.Targets, patternDeleteNode.Targets)
match = areEqualTableExprs(queryDeleteNode.Targets, patternDeleteNode.Targets)
if !match {
return false
}
Expand Down
111 changes: 90 additions & 21 deletions encryptor/queryDataEncryptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ func (encryptor *QueryDataEncryptor) encryptInsertQuery(ctx context.Context, ins
}

if encryptor.encryptor == nil {
return false, encryptor.onReturning(ctx, insert.Returning, tableName.ValueForConfig())
return false, encryptor.onReturning(ctx, insert.Returning, []sqlparser.TableExpr{&sqlparser.AliasedTableExpr{
Expr: insert.Table,
}})
}

var columnsName []string
Expand Down Expand Up @@ -319,15 +321,30 @@ func NewAliasToTableMapFromTables(tables []*AliasedTableName) AliasToTableMap {

// encryptUpdateQuery encrypt data in Update query and return true if any fields was encrypted, false if wasn't and error if error occurred
func (encryptor *QueryDataEncryptor) encryptUpdateQuery(ctx context.Context, update *sqlparser.Update, bindPlaceholders map[int]config.ColumnEncryptionSetting) (bool, error) {
tables := GetTablesWithAliases(update.TableExprs)
if !encryptor.hasTablesToEncrypt(tables) {
if len(update.TableExprs) == 0 {
return false, nil
}
if len(tables) == 0 {

fromTables := update.TableExprs

if len(update.From) != 0 {
fromTables = append(fromTables, update.From...)
}

tables := GetTablesWithAliases(fromTables)
if !encryptor.hasTablesToEncrypt(tables) {
return false, nil
}

qualifierMap := NewAliasToTableMapFromTables(tables)
firstTable := tables[0].TableName

// MySQL/MariaDB don`t support returning after update statements
// Postgres doest but expect only one table in tables expression, but also can have more tables in FROM statement
if encryptor.encryptor == nil {
return false, encryptor.onReturning(ctx, update.Returning, fromTables)
}

return encryptor.encryptUpdateExpressions(ctx, update.Exprs, firstTable, qualifierMap, bindPlaceholders)
}

Expand Down Expand Up @@ -398,26 +415,70 @@ func (encryptor *QueryDataEncryptor) onSelect(ctx context.Context, statement *sq
return false, nil
}

func (encryptor *QueryDataEncryptor) onReturning(ctx context.Context, returning sqlparser.Returning, tableName string) error {
func (encryptor *QueryDataEncryptor) onDelete(ctx context.Context, delete *sqlparser.Delete) (bool, error) {
if len(delete.TableExprs) == 0 {
return false, nil
}

fromTables := delete.TableExprs

if len(delete.Targets) != 0 {
fromTables = append(fromTables, delete.Targets...)
}

tables := GetTablesWithAliases(fromTables)
if !encryptor.hasTablesToEncrypt(tables) {
return false, nil
}

if encryptor.encryptor == nil {
return false, encryptor.onReturning(ctx, delete.Returning, fromTables)
}

return false, nil
}

func (encryptor *QueryDataEncryptor) onReturning(ctx context.Context, returning sqlparser.Returning, fromTables sqlparser.TableExprs) error {
if len(returning) == 0 {
return nil
}

schema := encryptor.schemaStore.GetTableSchema(tableName)
querySelectSettings := make([]*QueryDataItem, 0, 8)

if _, ok := returning[0].(*sqlparser.StarExpr); ok {
for _, name := range schema.Columns() {
if columnSetting := schema.GetColumnEncryptionSettings(name); columnSetting != nil {
querySelectSettings = append(querySelectSettings, &QueryDataItem{
setting: columnSetting,
tableName: tableName,
columnName: name,
})
for _, tableExp := range fromTables {
aliased, ok := tableExp.(*sqlparser.AliasedTableExpr)
if !ok {
continue
}
querySelectSettings = append(querySelectSettings, nil)

tableName, ok := aliased.Expr.(sqlparser.TableName)
if !ok {
continue
}

// if the Returning is star and we have more than one table in the query e.g.
// update table1 set did = tt.did from table2 as tt returning *
// and the table is not in the encryptor config we cant collect corresponding querySettings as we dont actual table representation
tableSchema := encryptor.schemaStore.GetTableSchema(tableName.Name.ValueForConfig())
if tableSchema == nil {
logrus.WithField("table", tableName.Name.ValueForConfig()).Info("Unable to collect querySettings for table not in encryptor config")
return errors.New("error to collect settings for unknown table")
}

for _, name := range tableSchema.Columns() {
if columnSetting := tableSchema.GetColumnEncryptionSettings(name); columnSetting != nil {
querySelectSettings = append(querySelectSettings, &QueryDataItem{
setting: columnSetting,
tableName: tableName.Name.ValueForConfig(),
columnName: name,
})
continue
}
querySelectSettings = append(querySelectSettings, nil)
}
}

clientSession := base.ClientSessionFromContext(ctx)
SaveQueryDataItemsToClientSession(clientSession, querySelectSettings)
encryptor.querySelectSettings = querySelectSettings
Expand All @@ -442,12 +503,20 @@ func (encryptor *QueryDataEncryptor) onReturning(ctx context.Context, returning
querySelectSettings = append(querySelectSettings, nil)
continue
}
rawColName := colName.Name.String()
if columnSetting := schema.GetColumnEncryptionSettings(rawColName); columnSetting != nil {

columnInfo, err := findColumnInfo(fromTables, colName, encryptor.schemaStore)
if err != nil {
querySelectSettings = append(querySelectSettings, nil)
continue
}

tableSchema := encryptor.schemaStore.GetTableSchema(columnInfo.Table)

if columnSetting := tableSchema.GetColumnEncryptionSettings(columnInfo.Name); columnSetting != nil {
querySelectSettings = append(querySelectSettings, &QueryDataItem{
setting: columnSetting,
tableName: tableName,
columnName: rawColName,
tableName: columnInfo.Table,
columnName: columnInfo.Name,
})
continue
}
Expand Down Expand Up @@ -476,9 +545,9 @@ func (encryptor *QueryDataEncryptor) OnQuery(ctx context.Context, query base.OnQ
case *sqlparser.Insert:
changed, err = encryptor.encryptInsertQuery(ctx, typedStatement, bindPlaceholders)
case *sqlparser.Update:
if encryptor.encryptor != nil {
changed, err = encryptor.encryptUpdateQuery(ctx, typedStatement, bindPlaceholders)
}
changed, err = encryptor.encryptUpdateQuery(ctx, typedStatement, bindPlaceholders)
case *sqlparser.Delete:
changed, err = encryptor.onDelete(ctx, typedStatement)
}
if err != nil {
return query, false, err
Expand Down