Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix search with aliased table #602

Merged
merged 3 commits into from Nov 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
210 changes: 68 additions & 142 deletions encryptor/searchable_query_filter.go
Expand Up @@ -63,65 +63,24 @@ func (filter *SearchableQueryFilter) FilterSearchableComparisons(statement sqlpa
return nil
}

defaultTable, aliasedTables := filter.filterInterestingTables(tableExps)
if len(aliasedTables) == 0 {
logrus.Debugln("No encryptable tables in search query")
// Walk through WHERE clauses of a SELECT statements...
whereExprs, err := getWhereStatements(statement)
if err != nil {
logrus.WithError(err).Debugln("Failed to extract WHERE clauses")
return nil
}

// Now take a closer look at WHERE clauses of the statement. We need only expressions
// which are simple equality comparisons, like "WHERE column = value".
exprs := filter.filterComparisonExprs(statement, defaultTable, aliasedTables)
if len(exprs) == 0 {
logrus.Debugln("No eligible comparisons in search query")
return nil
}
// And among those expressions, not all may refer to columns with searchable encryption
// enabled for them. Leave only those expressions which are searchable.
searchableExprs := filter.filterComparisons(exprs, defaultTable, aliasedTables)
if len(exprs) == 0 {
logrus.Debugln("No searchable comparisons in search query")
return nil
}
return searchableExprs
}

func (filter *SearchableQueryFilter) filterInterestingTables(fromExp sqlparser.TableExprs) (*AliasedTableName, AliasToTableMap) {
// Not all SELECT statements refer to tables at all.
tables := GetTablesWithAliases(fromExp)
if len(tables) == 0 {
return nil, nil
}

var defaultTable *AliasedTableName
var defaultTableName string
// if query contains table without alias we need to detect default table
// if no, we can ignore default table and AliasToTableMap will be used to map ColName with encryptor_config
if hasTablesWithoutAliases(fromExp) {
var err error
defaultTableName, err = getFirstTableWithoutAlias(fromExp)
var searchableExprs []SearchableExprItem
for _, whereExpr := range whereExprs {
comparisonExprs, err := filter.filterColumnEqualComparisonExprs(whereExpr, tableExps)
if err != nil {
logrus.WithError(err).Debugln("Failed to find first table without alias")
return nil, nil
logrus.WithError(err).Debugln("Failed to extract comparison expressions")
return nil
}
searchableExprs = append(searchableExprs, comparisonExprs...)
}

// And even then, we can work only with tables that we have an encryption schema for.
var encryptableTables []*AliasedTableName

for _, table := range tables {
if defaultTableName == table.TableName.Name.ValueForConfig() {
defaultTable = table
}

if v := filter.schemaStore.GetTableSchema(table.TableName.Name.ValueForConfig()); v != nil {
encryptableTables = append(encryptableTables, table)
}
}
if len(encryptableTables) == 0 {
return nil, nil
}
return defaultTable, NewAliasToTableMapFromTables(encryptableTables)
return searchableExprs
}

func (filter *SearchableQueryFilter) filterTableExpressions(statement sqlparser.Statement) (sqlparser.TableExprs, error) {
Expand All @@ -143,57 +102,8 @@ func (filter *SearchableQueryFilter) filterTableExpressions(statement sqlparser.
}
}

func (filter *SearchableQueryFilter) filterComparisonExprs(statement sqlparser.Statement, defaultTable *AliasedTableName, aliasedTables AliasToTableMap) []*sqlparser.ComparisonExpr {
// Walk through WHERE clauses of a SELECT statements...
whereExprs, err := getWhereStatements(statement)
if err != nil {
logrus.WithError(err).Debugln("Failed to extract WHERE clauses")
return nil
}
// ...and find all eligible comparison expressions in them.
var exprs []*sqlparser.ComparisonExpr
for _, whereExpr := range whereExprs {
comparisonExprs, err := filter.getColumnEqualComparisonExprs(whereExpr, defaultTable, aliasedTables)
if err != nil {
logrus.WithError(err).Debugln("Failed to extract comparison expressions")
return nil
}
exprs = append(exprs, comparisonExprs...)
}
return exprs
}

func (filter *SearchableQueryFilter) filterComparisons(exprs []*sqlparser.ComparisonExpr, defaultTable *AliasedTableName, aliasedTables AliasToTableMap) []SearchableExprItem {
filtered := make([]SearchableExprItem, 0, len(exprs))
for _, expr := range exprs {
// Leave out comparisons of columns which do not have a schema after alias resolution.
column := expr.Left.(*sqlparser.ColName)
schema := filter.getTableSchemaOfColumn(column, defaultTable, aliasedTables)
if schema == nil {
continue
}
// Also leave out those columns which are not searchable.
columnName := column.Name.String()
encryptionSetting := schema.GetColumnEncryptionSettings(columnName)

if encryptionSetting == nil {
continue
}

isComparableSetting := encryptionSetting.IsSearchable()
if filter.mode == QueryFilterModeConsistentTokenization {
isComparableSetting = encryptionSetting.IsConsistentTokenization()
}

if isComparableSetting {
filtered = append(filtered, SearchableExprItem{Expr: expr, Setting: encryptionSetting})
}
}
return filtered
}

func (filter *SearchableQueryFilter) getColumnSetting(column *sqlparser.ColName, defaultTable *AliasedTableName, aliasedTables AliasToTableMap) config.ColumnEncryptionSetting {
schema := filter.getTableSchemaOfColumn(column, defaultTable, aliasedTables)
func (filter *SearchableQueryFilter) getColumnSetting(column *sqlparser.ColName, columnInfo columnInfo) config.ColumnEncryptionSetting {
schema := filter.schemaStore.GetTableSchema(columnInfo.Table)
if schema == nil {
return nil
}
Expand All @@ -202,14 +112,6 @@ func (filter *SearchableQueryFilter) getColumnSetting(column *sqlparser.ColName,
return schema.GetColumnEncryptionSettings(columnName)
}

func (filter *SearchableQueryFilter) getTableSchemaOfColumn(column *sqlparser.ColName, defaultTable *AliasedTableName, aliasedTables AliasToTableMap) config.TableSchema {
if column.Qualifier.Qualifier.IsEmpty() && column.Qualifier.Name.IsEmpty() {
return filter.schemaStore.GetTableSchema(defaultTable.TableName.Name.ValueForConfig())
}
tableName := aliasedTables[column.Qualifier.Name.ValueForConfig()]
return filter.schemaStore.GetTableSchema(tableName)
}

func getWhereStatements(stmt sqlparser.Statement) ([]*sqlparser.Where, error) {
var whereStatements []*sqlparser.Where
err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
Expand All @@ -235,48 +137,72 @@ func isSupportedSQLVal(val *sqlparser.SQLVal) bool {
return false
}

// getColumnEqualComparisonExprs return only <ColName> = <VALUE> or <ColName> != <VALUE> or <ColName> <=> <VALUE> expressions
func (filter *SearchableQueryFilter) getColumnEqualComparisonExprs(stmt sqlparser.SQLNode, defaultTable *AliasedTableName, aliasedTables AliasToTableMap) ([]*sqlparser.ComparisonExpr, error) {
var exprs []*sqlparser.ComparisonExpr
// filterColumnEqualComparisonExprs return only <ColName> = <VALUE> or <ColName> != <VALUE> or <ColName> <=> <VALUE> expressions
func (filter *SearchableQueryFilter) filterColumnEqualComparisonExprs(stmt sqlparser.SQLNode, tableExpr sqlparser.TableExprs) ([]SearchableExprItem, error) {
var exprs []SearchableExprItem

err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
if comparisonExpr, ok := node.(*sqlparser.ComparisonExpr); ok {
lColumn, ok := comparisonExpr.Left.(*sqlparser.ColName)
if !ok {
return true, nil
}
comparisonExpr, ok := node.(*sqlparser.ComparisonExpr)
if !ok {
return true, nil
}

lColumn, ok := comparisonExpr.Left.(*sqlparser.ColName)
if !ok {
return true, nil
}

lColumnSetting := filter.getColumnSetting(lColumn, defaultTable, aliasedTables)
if lColumnSetting == nil {
columnInfo, err := findColumnInfo(tableExpr, lColumn, filter.schemaStore)
if err != nil {
return true, nil
}

lColumnSetting := filter.getColumnSetting(lColumn, columnInfo)
if lColumnSetting == nil {
return true, nil
}

if !lColumnSetting.IsSearchable() && !lColumnSetting.IsConsistentTokenization() {
return true, nil
}

// check if left column isSearchable or consistent tokenized and right column is sqlparser.ColName
// we want to log the warn message that searchable tokenization/encryption can work only with <ColName> = <VALUE> statements
// however, there is one exception - for searchable encryption it can be the scenario where we have: join table1 t1 on t1.surname = t2.surname
// and if t1 and t2 are tables from encryptor_config and t1.surname and t2.surname are searchable, we want to have: join table1 t1 on substr(t1.surname, ...) = substr(t2.surname, ...)
if rColumn, ok := comparisonExpr.Right.(*sqlparser.ColName); ok {
// get right columnSetting to check weather it is searchable too

columnInfo, err := findColumnInfo(tableExpr, rColumn, filter.schemaStore)
if err != nil {
return true, nil
}

// check if left column isSearchable or consistent tokenized and right column is sqlparser.ColName
// we want to log the warn message that searchable tokenization/encryption can work only with <ColName> = <VALUE> statements
// however, there is one exception - for searchable encryption it can be the scenario where we have: join table1 t1 on t1.surname = t2.surname
// and if t1 and t2 are tables from encryptor_config and t1.surname and t2.surname are searchable, we want to have: join table1 t1 on substr(t1.surname, ...) = substr(t2.surname, ...)
if lColumnSetting.IsSearchable() || lColumnSetting.IsConsistentTokenization() {
if rColumn, ok := comparisonExpr.Right.(*sqlparser.ColName); ok {
// get right columnSetting to check weather it is searchable too
rColumnSetting := filter.getColumnSetting(rColumn, defaultTable, aliasedTables)
if rColumnSetting != nil {
if rColumnSetting.IsSearchable() {
exprs = append(exprs, comparisonExpr)
return true, nil
}
}

logrus.Infoln("Searchable encryption/tokenization support equal comparison only by SQLVal but not by ColName")
rColumnSetting := filter.getColumnSetting(rColumn, columnInfo)
if rColumnSetting != nil {
if rColumnSetting.IsSearchable() {
exprs = append(exprs, SearchableExprItem{
Expr: comparisonExpr,
Setting: rColumnSetting,
})
return true, nil
}
}

if sqlVal, ok := comparisonExpr.Right.(*sqlparser.SQLVal); ok && isSupportedSQLVal(sqlVal) {
if comparisonExpr.Operator == sqlparser.EqualStr || comparisonExpr.Operator == sqlparser.NotEqualStr || comparisonExpr.Operator == sqlparser.NullSafeEqualStr {
if _, ok := comparisonExpr.Left.(*sqlparser.ColName); ok {
exprs = append(exprs, comparisonExpr)
}
logrus.Infoln("Searchable encryption/tokenization support equal comparison only by SQLVal but not by ColName")
}

if sqlVal, ok := comparisonExpr.Right.(*sqlparser.SQLVal); ok && isSupportedSQLVal(sqlVal) {
if comparisonExpr.Operator == sqlparser.EqualStr || comparisonExpr.Operator == sqlparser.NotEqualStr || comparisonExpr.Operator == sqlparser.NullSafeEqualStr {
if _, ok := comparisonExpr.Left.(*sqlparser.ColName); ok {
exprs = append(exprs, SearchableExprItem{
Expr: comparisonExpr,
Setting: lColumnSetting,
})
}
}
}

return true, nil
}, stmt)
return exprs, err
Expand Down
55 changes: 17 additions & 38 deletions encryptor/searchable_query_filter_test.go
Expand Up @@ -5,13 +5,16 @@ import (

"github.com/cossacklabs/acra/encryptor/config"
"github.com/cossacklabs/acra/sqlparser"
"github.com/cossacklabs/acra/sqlparser/dialect/postgresql"
)

func TestGetTableSchemaOfColumnMatchConfigTable(t *testing.T) {
tableNameUpperCase := "SomeTableInUpperCase"
configStr := `
schemas:
- table: sometableinuppercase
columns:
- default_client_id
- specified_client_id
encrypted:
- column: "default_client_id"
- column: specified_client_id
Expand All @@ -22,53 +25,29 @@ schemas:
t.Fatalf("Can't parse config: %s", err.Error())
}

searchableQueryFilter := SearchableQueryFilter{
schemaStore: schemaStore,
}
query := `SELECT * from SomeTableInUpperCase WHERE default_client_id = 'value'`

tableNamesWithQuotes := sqlparser.NewTableIdentWithQuotes(tableNameUpperCase, '"')
schemaTable := searchableQueryFilter.getTableSchemaOfColumn(&sqlparser.ColName{}, &AliasedTableName{
TableName: sqlparser.TableName{
Name: tableNamesWithQuotes,
},
}, AliasToTableMap{})

if schemaTable == nil {
t.Fatalf("Expect not nil schemaTable, matched with config")
stmt, err := sqlparser.ParseWithDialect(postgresql.NewPostgreSQLDialect(), query)
if err != nil {
t.Fatalf("Can't parse query statement: %s", err.Error())
}
}

func TestFilterInterestingTables(t *testing.T) {
tableNameUpperCase := "SomeTableInUpperCase"
configStr := `
schemas:
- table: sometableinuppercase
encrypted:
- column: "default_client_id"
- column: specified_client_id
client_id: specified_client_id
`
schemaStore, err := config.MapTableSchemaStoreFromConfig([]byte(configStr), config.UseMySQL)
selectQuery := stmt.(*sqlparser.Select)
columnInfo, err := findColumnInfo(selectQuery.From, selectQuery.Where.Expr.(*sqlparser.ComparisonExpr).Left.(*sqlparser.ColName), schemaStore)
if err != nil {
t.Fatalf("Can't parse config: %s", err.Error())
t.Fatalf("Can't find column info: %s", err.Error())
}

searchableQueryFilter := SearchableQueryFilter{
schemaStore: schemaStore,
}

tableNamesWithQuotes := sqlparser.NewTableIdentWithQuotes(tableNameUpperCase, '"')

aliasedTable, _ := searchableQueryFilter.filterInterestingTables(sqlparser.TableExprs{
&sqlparser.AliasedTableExpr{
Expr: sqlparser.TableName{
Name: tableNamesWithQuotes,
},
},
})
schemaTable := searchableQueryFilter.getColumnSetting(&sqlparser.ColName{
Name: sqlparser.NewColIdent("default_client_id"),
}, columnInfo)

if aliasedTable == nil {
t.Fatalf("Expect not nil aliasedTable, matched with config")
if schemaTable == nil {
t.Fatalf("Expect not nil schemaTable, matched with config")
}
}

Expand All @@ -93,7 +72,7 @@ func Test_getColumnEqualComparisonExprs_NotColumnComparisonQueries(t *testing.T)
t.Fatalf("expected no error on parsing valid WHERE clause query - %s", err.Error())
}

compExprs, err := searchableQueryFilter.getColumnEqualComparisonExprs(whereStatements[0], nil, nil)
compExprs, err := searchableQueryFilter.filterColumnEqualComparisonExprs(whereStatements[0], statement.(*sqlparser.Select).From)
if err != nil {
t.Fatal(err)
}
Expand Down