diff --git a/rust/ql/lib/codeql/rust/security/SqlInjectionExtensions.qll b/rust/ql/lib/codeql/rust/security/SqlInjectionExtensions.qll index f2921ef0cc13..99d79642e78e 100644 --- a/rust/ql/lib/codeql/rust/security/SqlInjectionExtensions.qll +++ b/rust/ql/lib/codeql/rust/security/SqlInjectionExtensions.qll @@ -9,6 +9,8 @@ private import codeql.rust.dataflow.DataFlow private import codeql.rust.dataflow.FlowSink private import codeql.rust.Concepts private import codeql.util.Unit +private import codeql.rust.controlflow.ControlFlowGraph as Cfg +private import codeql.rust.controlflow.CfgNodes as CfgNodes /** * Provides default sources, sinks and barriers for detecting SQL injection @@ -57,4 +59,157 @@ module SqlInjection { private class ModelsAsDataSink extends Sink { ModelsAsDataSink() { sinkNode(this, "sql-injection") } } + + /** + * A barrier for SQL injection vulnerabilities for nodes that result from parsing to numeric types. + * + * Numeric types are considered safe because they cannot contain SQL injection payloads. + * This barrier stops taint flow when untrusted data is parsed to a numeric type like i32, u32, f64, etc. + */ + private class NumericTypeBarrier extends Barrier { + NumericTypeBarrier() { + // Match dataflow nodes that come from the result of parsing strings to numeric types + // The parse method is called with a turbofish operator specifying the target type: .parse::() + exists(MethodCallExpr parse, string typeName | + parse.getIdentifier().getText() = "parse" and + this.asExpr().getExpr() = parse and + // Extract the type name from the generic argument list + typeName = parse.getGenericArgList().toString() and + // Check if it's a numeric type + ( + typeName.matches("%i8%") or + typeName.matches("%i16%") or + typeName.matches("%i32%") or + typeName.matches("%i64%") or + typeName.matches("%i128%") or + typeName.matches("%isize%") or + typeName.matches("%u8%") or + typeName.matches("%u16%") or + typeName.matches("%u32%") or + typeName.matches("%u64%") or + typeName.matches("%u128%") or + typeName.matches("%usize%") or + typeName.matches("%f32%") or + typeName.matches("%f64%") + ) + ) + } + } + + /** + * Holds if comparison `guard` validates `node` by comparing it with a string literal + * when `branch` is true. + */ + private predicate stringConstCompare(CfgNodes::AstCfgNode guard, Cfg::CfgNode node, boolean branch) { + exists(EqualsOperation eq | + guard = eq.getACfgNode() and + branch = true and + ( + // node == "literal" or "literal" == node + node = eq.getLhs().getACfgNode() and + eq.getRhs() instanceof StringLiteralExpr + or + node = eq.getRhs().getACfgNode() and + eq.getLhs() instanceof StringLiteralExpr + ) + ) + or + exists(NotEqualsOperation ne | + guard = ne.getACfgNode() and + branch = false and + ( + // node != "literal" or "literal" != node + node = ne.getLhs().getACfgNode() and + ne.getRhs() instanceof StringLiteralExpr + or + node = ne.getRhs().getACfgNode() and + ne.getLhs() instanceof StringLiteralExpr + ) + ) + or + // Handle multiple comparisons with OR + stringConstCompareOr(guard, node, branch) + } + + /** + * Holds if `guard` is an OR expression where both operands compare `node` + * with string literals when `branch` is true. + */ + private predicate stringConstCompareOr( + CfgNodes::AstCfgNode guard, Cfg::CfgNode node, boolean branch + ) { + exists(LogicalOrExpr orExpr, EqualsOperation eqLeft, EqualsOperation eqRight | + guard = orExpr.getACfgNode() and + branch = true and + eqLeft.getACfgNode() = orExpr.getLhs().getACfgNode() and + eqRight.getACfgNode() = orExpr.getRhs().getACfgNode() and + // Both sides must compare the same node against string literals + ( + node = eqLeft.getLhs().getACfgNode() and + node = eqRight.getLhs().getACfgNode() and + eqLeft.getRhs() instanceof StringLiteralExpr and + eqRight.getRhs() instanceof StringLiteralExpr + or + node = eqLeft.getRhs().getACfgNode() and + node = eqRight.getRhs().getACfgNode() and + eqLeft.getLhs() instanceof StringLiteralExpr and + eqRight.getLhs() instanceof StringLiteralExpr + ) + ) + } + + /** + * A barrier for SQL injection vulnerabilities where the data is validated + * against one or more constant string values. + * + * For example, `if remote_string == "admin"` or `if remote_string == "person" || remote_string == "vehicle"`. + */ + private class StringConstCompareBarrier extends Barrier { + StringConstCompareBarrier() { + this = DataFlow::BarrierGuard::getABarrierNode() + } + } + + /** + * Holds if `guard` is a call to a collection contains/includes method that checks + * if `node` is in a collection of string literals. + */ + private predicate stringConstArrayInclusionCall( + CfgNodes::AstCfgNode guard, Cfg::CfgNode node, boolean branch + ) { + exists(MethodCallExpr call | + guard = call.getACfgNode() and + branch = true and + // Check for contains() method call + call.getIdentifier().getText() = "contains" and + // The argument should be the node we're checking + node = call.getArgList().getAnArg().getACfgNode() and + // The receiver should be an array/slice/collection of string literals + exists(Expr receiver | + receiver = call.getReceiver() and + isArrayOfStringLiterals(receiver) + ) + ) + } + + /** + * Holds if `expr` is an array or slice literal containing only string literals. + */ + private predicate isArrayOfStringLiterals(Expr expr) { + // Array literal: ["str1", "str2", "str3"] + expr instanceof ArrayListExpr and + forex(Expr elem | elem = expr.(ArrayListExpr).getExpr(_) | elem instanceof StringLiteralExpr) + } + + /** + * A barrier for SQL injection where the data is validated by checking + * if it's contained in a collection of constant string values. + * + * For example, `if ["admin", "user"].contains(&remote_string)`. + */ + private class StringConstArrayInclusionCallBarrier extends Barrier { + StringConstArrayInclusionCallBarrier() { + this = DataFlow::BarrierGuard::getABarrierNode() + } + } } diff --git a/rust/ql/src/change-notes/2025-10-30-sql-injection-barriers.md b/rust/ql/src/change-notes/2025-10-30-sql-injection-barriers.md new file mode 100644 index 000000000000..16d7a85a6f3e --- /dev/null +++ b/rust/ql/src/change-notes/2025-10-30-sql-injection-barriers.md @@ -0,0 +1,7 @@ +--- +category: minorAnalysis +--- +* The `rust/sql-injection` query now includes taint flow barriers to reduce false positives. Specifically: + * Data parsed to numeric types (e.g., `.parse::()`) is now recognized as safe. + * Data validated against one or more constant string values (e.g., `if x == "admin"` or `if x == "user" || x == "guest"`) is now recognized as safe within the validated branch. + * Data validated using collection membership checks against string literals (e.g., `if ["admin", "user"].contains(&x)`) is now recognized as safe within the validated branch. diff --git a/rust/ql/test/query-tests/security/CWE-089/barriers.rs b/rust/ql/test/query-tests/security/CWE-089/barriers.rs new file mode 100644 index 000000000000..5171917b71f4 --- /dev/null +++ b/rust/ql/test/query-tests/security/CWE-089/barriers.rs @@ -0,0 +1,148 @@ +use sqlx::Connection; +use sqlx::Executor; + +/** + * Test cases for SQL injection barriers/sanitizers. + * These test that proper validation blocks taint flow. + */ + +async fn test_barriers(enable_remote: bool) -> Result<(), sqlx::Error> { + let pool: sqlx::Pool = sqlx::mysql::MySqlPool::connect("").await?; + + // Get remote data (untrusted source) + let remote_string = reqwest::blocking::get("http://example.com/").unwrap().text().unwrap_or(String::from("Alice")); // $ Source=remote_barrier + + // --- Barrier 1: Numeric type sanitization --- + // When untrusted data is parsed to a numeric type, it should be safe + let remote_number = remote_string.parse::().unwrap_or(0); + let safe_numeric_query = format!("SELECT * FROM people WHERE id={remote_number}"); + let _ = sqlx::query(safe_numeric_query.as_str()).execute(&pool).await?; // $ sql-sink SPURIOUS: Alert[rust/sql-injection]=remote_barrier + + // Also test with other numeric types + let remote_u32 = remote_string.parse::().unwrap_or(0); + let safe_u32_query = format!("SELECT * FROM people WHERE id={remote_u32}"); + let _ = sqlx::query(safe_u32_query.as_str()).execute(&pool).await?; // $ sql-sink SPURIOUS: Alert[rust/sql-injection]=remote_barrier + + let remote_i64 = remote_string.parse::().unwrap_or(0); + let safe_i64_query = format!("SELECT * FROM people WHERE id={remote_i64}"); + let _ = sqlx::query(safe_i64_query.as_str()).execute(&pool).await?; // $ sql-sink SPURIOUS: Alert[rust/sql-injection]=remote_barrier + + let remote_f64 = remote_string.parse::().unwrap_or(0.0); + let safe_f64_query = format!("SELECT * FROM people WHERE price={remote_f64}"); + let _ = sqlx::query(safe_f64_query.as_str()).execute(&pool).await?; // $ sql-sink SPURIOUS: Alert[rust/sql-injection]=remote_barrier + + // --- Barrier 2: Single constant comparison --- + // When untrusted data is compared with a single constant value + if enable_remote { + let remote_string2 = reqwest::blocking::get("http://example.com/").unwrap().text().unwrap_or(String::from("Alice")); // $ Source=remote_barrier2 + + // Safe: validated against single constant + if remote_string2 == "admin" { + let safe_single_const = format!("SELECT * FROM people WHERE role='{remote_string2}'"); + let _ = sqlx::query(safe_single_const.as_str()).execute(&pool).await?; // $ sql-sink SPURIOUS: Alert[rust/sql-injection]=remote_barrier2 + } + + // Safe: validated with != (false branch is safe) + if remote_string2 != "admin" { + // unsafe branch - still tainted + let unsafe_ne_query = format!("SELECT * FROM people WHERE role='{remote_string2}'"); + let _ = sqlx::query(unsafe_ne_query.as_str()).execute(&pool).await?; // $ sql-sink Alert[rust/sql-injection]=remote_barrier2 + } else { + // safe branch - validated to be "admin" + let safe_ne_query = format!("SELECT * FROM people WHERE role='{remote_string2}'"); + let _ = sqlx::query(safe_ne_query.as_str()).execute(&pool).await?; // $ sql-sink SPURIOUS: Alert[rust/sql-injection]=remote_barrier2 + } + } + + // --- Barrier 3: Multiple constant comparison (OR pattern) --- + // When untrusted data is compared with multiple constant values + if enable_remote { + let remote_string3 = reqwest::blocking::get("http://example.com/").unwrap().text().unwrap_or(String::from("Alice")); // $ Source=remote_barrier3 + + // Safe: validated against multiple constants with OR + if remote_string3 == "person" || remote_string3 == "vehicle" { + let safe_multi_const = format!("SELECT * FROM entities WHERE type='{remote_string3}'"); + let _ = sqlx::query(safe_multi_const.as_str()).execute(&pool).await?; // $ sql-sink SPURIOUS: Alert[rust/sql-injection]=remote_barrier3 + } + + // Safe: validated against multiple constants with OR (more than 2) + if remote_string3 == "alice" || remote_string3 == "bob" || remote_string3 == "charlie" { + let safe_multi_const_3 = format!("SELECT * FROM people WHERE name='{remote_string3}'"); + let _ = sqlx::query(safe_multi_const_3.as_str()).execute(&pool).await?; // $ sql-sink SPURIOUS: Alert[rust/sql-injection]=remote_barrier3 + } + } + + // --- Barrier 4: Collection/array constant comparison --- + // When untrusted data is checked against a collection of constant values + if enable_remote { + let remote_string4 = reqwest::blocking::get("http://example.com/").unwrap().text().unwrap_or(String::from("Alice")); // $ Source=remote_barrier4 + + // Safe: validated against an array of constants + let allowed_roles = vec!["admin", "user", "guest"]; + if allowed_roles.contains(&remote_string4.as_str()) { + let safe_array_check = format!("SELECT * FROM people WHERE role='{remote_string4}'"); + let _ = sqlx::query(safe_array_check.as_str()).execute(&pool).await?; // $ sql-sink SPURIOUS: Alert[rust/sql-injection]=remote_barrier4 + } + + // Safe: validated with slice contains + if ["manager", "employee", "contractor"].contains(&remote_string4.as_str()) { + let safe_slice_check = format!("SELECT * FROM people WHERE role='{remote_string4}'"); + let _ = sqlx::query(safe_slice_check.as_str()).execute(&pool).await?; // $ sql-sink SPURIOUS: Alert[rust/sql-injection]=remote_barrier4 + } + + // Safe: validated with HashSet + use std::collections::HashSet; + let mut allowed_set = HashSet::new(); + allowed_set.insert("alice"); + allowed_set.insert("bob"); + allowed_set.insert("charlie"); + if allowed_set.contains(&remote_string4.as_str()) { + let safe_set_check = format!("SELECT * FROM people WHERE name='{remote_string4}'"); + let _ = sqlx::query(safe_set_check.as_str()).execute(&pool).await?; // $ sql-sink SPURIOUS: Alert[rust/sql-injection]=remote_barrier4 + } + } + + // --- Negative test cases: Incorrect sanitization --- + + if enable_remote { + let remote_string5 = reqwest::blocking::get("http://example.com/").unwrap().text().unwrap_or(String::from("Alice")); // $ Source=remote_barrier5 + + // Unsafe: comparison with non-constant value (another variable) + let other_string = String::from("admin"); + if remote_string5 == other_string { + let unsafe_non_const = format!("SELECT * FROM people WHERE role='{remote_string5}'"); + let _ = sqlx::query(unsafe_non_const.as_str()).execute(&pool).await?; // $ sql-sink Alert[rust/sql-injection]=remote_barrier5 + } + + // Unsafe: comparison with OR where one side is not validated + if remote_string5 == "admin" || remote_string5.len() > 5 { + let unsafe_mixed_or = format!("SELECT * FROM people WHERE role='{remote_string5}'"); + let _ = sqlx::query(unsafe_mixed_or.as_str()).execute(&pool).await?; // $ sql-sink Alert[rust/sql-injection]=remote_barrier5 + } + + // Unsafe: array contains with non-constant elements + let dynamic_vec = vec![other_string.as_str(), "user"]; + if dynamic_vec.contains(&remote_string5.as_str()) { + let unsafe_dynamic_array = format!("SELECT * FROM people WHERE role='{remote_string5}'"); + let _ = sqlx::query(unsafe_dynamic_array.as_str()).execute(&pool).await?; // $ sql-sink Alert[rust/sql-injection]=remote_barrier5 + } + + // Unsafe: checking length or other properties (not value equality) + if remote_string5.len() == 5 { + let unsafe_length_check = format!("SELECT * FROM people WHERE name='{remote_string5}'"); + let _ = sqlx::query(unsafe_length_check.as_str()).execute(&pool).await?; // $ sql-sink Alert[rust/sql-injection]=remote_barrier5 + } + } + + Ok(()) +} + +fn main() { + println!("--- CWE-089 barriers.rs test ---"); + let enable_remote = std::env::args().nth(1) == Some(String::from("ENABLE_REMOTE")); + + match futures::executor::block_on(test_barriers(enable_remote)) { + Ok(_) => println!(" successful!"), + Err(e) => println!(" error: {}", e), + } +}