Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 155 additions & 0 deletions rust/ql/lib/codeql/rust/security/SqlInjectionExtensions.qll
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ private import codeql.rust.dataflow.DataFlow
private import codeql.rust.dataflow.FlowSink
private import codeql.rust.Concepts
private import codeql.util.Unit
private import codeql.rust.controlflow.ControlFlowGraph as Cfg
private import codeql.rust.controlflow.CfgNodes as CfgNodes

/**
* Provides default sources, sinks and barriers for detecting SQL injection
Expand Down Expand Up @@ -57,4 +59,157 @@ module SqlInjection {
private class ModelsAsDataSink extends Sink {
ModelsAsDataSink() { sinkNode(this, "sql-injection") }
}

/**
* A barrier for SQL injection vulnerabilities for nodes that result from parsing to numeric types.
*
* Numeric types are considered safe because they cannot contain SQL injection payloads.
* This barrier stops taint flow when untrusted data is parsed to a numeric type like i32, u32, f64, etc.
*/
private class NumericTypeBarrier extends Barrier {
NumericTypeBarrier() {
// Match dataflow nodes that come from the result of parsing strings to numeric types
// The parse method is called with a turbofish operator specifying the target type: .parse::<i32>()
exists(MethodCallExpr parse, string typeName |
parse.getIdentifier().getText() = "parse" and
this.asExpr().getExpr() = parse and
// Extract the type name from the generic argument list
typeName = parse.getGenericArgList().toString() and
// Check if it's a numeric type
(
typeName.matches("%i8%") or
typeName.matches("%i16%") or
typeName.matches("%i32%") or
typeName.matches("%i64%") or
typeName.matches("%i128%") or
typeName.matches("%isize%") or
typeName.matches("%u8%") or
typeName.matches("%u16%") or
typeName.matches("%u32%") or
typeName.matches("%u64%") or
typeName.matches("%u128%") or
typeName.matches("%usize%") or
typeName.matches("%f32%") or
typeName.matches("%f64%")
)
)
}
}

/**
* Holds if comparison `guard` validates `node` by comparing it with a string literal
* when `branch` is true.
*/
private predicate stringConstCompare(CfgNodes::AstCfgNode guard, Cfg::CfgNode node, boolean branch) {
exists(EqualsOperation eq |
guard = eq.getACfgNode() and
branch = true and
(
// node == "literal" or "literal" == node
node = eq.getLhs().getACfgNode() and
eq.getRhs() instanceof StringLiteralExpr
or
node = eq.getRhs().getACfgNode() and
eq.getLhs() instanceof StringLiteralExpr
)
)
or
exists(NotEqualsOperation ne |
guard = ne.getACfgNode() and
branch = false and
(
// node != "literal" or "literal" != node
node = ne.getLhs().getACfgNode() and
ne.getRhs() instanceof StringLiteralExpr
or
node = ne.getRhs().getACfgNode() and
ne.getLhs() instanceof StringLiteralExpr
)
)
or
// Handle multiple comparisons with OR
stringConstCompareOr(guard, node, branch)
}

/**
* Holds if `guard` is an OR expression where both operands compare `node`
* with string literals when `branch` is true.
*/
private predicate stringConstCompareOr(
CfgNodes::AstCfgNode guard, Cfg::CfgNode node, boolean branch
) {
exists(LogicalOrExpr orExpr, EqualsOperation eqLeft, EqualsOperation eqRight |
guard = orExpr.getACfgNode() and
branch = true and
eqLeft.getACfgNode() = orExpr.getLhs().getACfgNode() and
eqRight.getACfgNode() = orExpr.getRhs().getACfgNode() and
// Both sides must compare the same node against string literals
(
node = eqLeft.getLhs().getACfgNode() and
node = eqRight.getLhs().getACfgNode() and
eqLeft.getRhs() instanceof StringLiteralExpr and
eqRight.getRhs() instanceof StringLiteralExpr
or
node = eqLeft.getRhs().getACfgNode() and
node = eqRight.getRhs().getACfgNode() and
eqLeft.getLhs() instanceof StringLiteralExpr and
eqRight.getLhs() instanceof StringLiteralExpr
)
)
}

/**
* A barrier for SQL injection vulnerabilities where the data is validated
* against one or more constant string values.
*
* For example, `if remote_string == "admin"` or `if remote_string == "person" || remote_string == "vehicle"`.
*/
private class StringConstCompareBarrier extends Barrier {
StringConstCompareBarrier() {
this = DataFlow::BarrierGuard<stringConstCompare/3>::getABarrierNode()
}
}

/**
* Holds if `guard` is a call to a collection contains/includes method that checks
* if `node` is in a collection of string literals.
*/
private predicate stringConstArrayInclusionCall(
CfgNodes::AstCfgNode guard, Cfg::CfgNode node, boolean branch
) {
exists(MethodCallExpr call |
guard = call.getACfgNode() and
branch = true and
// Check for contains() method call
call.getIdentifier().getText() = "contains" and
// The argument should be the node we're checking
node = call.getArgList().getAnArg().getACfgNode() and
// The receiver should be an array/slice/collection of string literals
exists(Expr receiver |
receiver = call.getReceiver() and
isArrayOfStringLiterals(receiver)
)
)
}

/**
* Holds if `expr` is an array or slice literal containing only string literals.
*/
private predicate isArrayOfStringLiterals(Expr expr) {
// Array literal: ["str1", "str2", "str3"]
expr instanceof ArrayListExpr and
forex(Expr elem | elem = expr.(ArrayListExpr).getExpr(_) | elem instanceof StringLiteralExpr)
}

/**
* A barrier for SQL injection where the data is validated by checking
* if it's contained in a collection of constant string values.
*
* For example, `if ["admin", "user"].contains(&remote_string)`.
*/
private class StringConstArrayInclusionCallBarrier extends Barrier {
StringConstArrayInclusionCallBarrier() {
this = DataFlow::BarrierGuard<stringConstArrayInclusionCall/3>::getABarrierNode()
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
category: minorAnalysis
---
* The `rust/sql-injection` query now includes taint flow barriers to reduce false positives. Specifically:
* Data parsed to numeric types (e.g., `.parse::<i32>()`) is now recognized as safe.
* Data validated against one or more constant string values (e.g., `if x == "admin"` or `if x == "user" || x == "guest"`) is now recognized as safe within the validated branch.
* Data validated using collection membership checks against string literals (e.g., `if ["admin", "user"].contains(&x)`) is now recognized as safe within the validated branch.
148 changes: 148 additions & 0 deletions rust/ql/test/query-tests/security/CWE-089/barriers.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
use sqlx::Connection;
use sqlx::Executor;

/**
* Test cases for SQL injection barriers/sanitizers.
* These test that proper validation blocks taint flow.
*/

async fn test_barriers(enable_remote: bool) -> Result<(), sqlx::Error> {
let pool: sqlx::Pool<sqlx::MySql> = sqlx::mysql::MySqlPool::connect("").await?;

// Get remote data (untrusted source)
let remote_string = reqwest::blocking::get("http://example.com/").unwrap().text().unwrap_or(String::from("Alice")); // $ Source=remote_barrier

// --- Barrier 1: Numeric type sanitization ---
// When untrusted data is parsed to a numeric type, it should be safe
let remote_number = remote_string.parse::<i32>().unwrap_or(0);
let safe_numeric_query = format!("SELECT * FROM people WHERE id={remote_number}");
let _ = sqlx::query(safe_numeric_query.as_str()).execute(&pool).await?; // $ sql-sink SPURIOUS: Alert[rust/sql-injection]=remote_barrier

// Also test with other numeric types
let remote_u32 = remote_string.parse::<u32>().unwrap_or(0);
let safe_u32_query = format!("SELECT * FROM people WHERE id={remote_u32}");
let _ = sqlx::query(safe_u32_query.as_str()).execute(&pool).await?; // $ sql-sink SPURIOUS: Alert[rust/sql-injection]=remote_barrier

let remote_i64 = remote_string.parse::<i64>().unwrap_or(0);
let safe_i64_query = format!("SELECT * FROM people WHERE id={remote_i64}");
let _ = sqlx::query(safe_i64_query.as_str()).execute(&pool).await?; // $ sql-sink SPURIOUS: Alert[rust/sql-injection]=remote_barrier

let remote_f64 = remote_string.parse::<f64>().unwrap_or(0.0);
let safe_f64_query = format!("SELECT * FROM people WHERE price={remote_f64}");
let _ = sqlx::query(safe_f64_query.as_str()).execute(&pool).await?; // $ sql-sink SPURIOUS: Alert[rust/sql-injection]=remote_barrier

// --- Barrier 2: Single constant comparison ---
// When untrusted data is compared with a single constant value
if enable_remote {
let remote_string2 = reqwest::blocking::get("http://example.com/").unwrap().text().unwrap_or(String::from("Alice")); // $ Source=remote_barrier2

// Safe: validated against single constant
if remote_string2 == "admin" {
let safe_single_const = format!("SELECT * FROM people WHERE role='{remote_string2}'");
let _ = sqlx::query(safe_single_const.as_str()).execute(&pool).await?; // $ sql-sink SPURIOUS: Alert[rust/sql-injection]=remote_barrier2
}

// Safe: validated with != (false branch is safe)
if remote_string2 != "admin" {
// unsafe branch - still tainted
let unsafe_ne_query = format!("SELECT * FROM people WHERE role='{remote_string2}'");
let _ = sqlx::query(unsafe_ne_query.as_str()).execute(&pool).await?; // $ sql-sink Alert[rust/sql-injection]=remote_barrier2
} else {
// safe branch - validated to be "admin"
let safe_ne_query = format!("SELECT * FROM people WHERE role='{remote_string2}'");
let _ = sqlx::query(safe_ne_query.as_str()).execute(&pool).await?; // $ sql-sink SPURIOUS: Alert[rust/sql-injection]=remote_barrier2
}
}

// --- Barrier 3: Multiple constant comparison (OR pattern) ---
// When untrusted data is compared with multiple constant values
if enable_remote {
let remote_string3 = reqwest::blocking::get("http://example.com/").unwrap().text().unwrap_or(String::from("Alice")); // $ Source=remote_barrier3

// Safe: validated against multiple constants with OR
if remote_string3 == "person" || remote_string3 == "vehicle" {
let safe_multi_const = format!("SELECT * FROM entities WHERE type='{remote_string3}'");
let _ = sqlx::query(safe_multi_const.as_str()).execute(&pool).await?; // $ sql-sink SPURIOUS: Alert[rust/sql-injection]=remote_barrier3
}

// Safe: validated against multiple constants with OR (more than 2)
if remote_string3 == "alice" || remote_string3 == "bob" || remote_string3 == "charlie" {
let safe_multi_const_3 = format!("SELECT * FROM people WHERE name='{remote_string3}'");
let _ = sqlx::query(safe_multi_const_3.as_str()).execute(&pool).await?; // $ sql-sink SPURIOUS: Alert[rust/sql-injection]=remote_barrier3
}
}

// --- Barrier 4: Collection/array constant comparison ---
// When untrusted data is checked against a collection of constant values
if enable_remote {
let remote_string4 = reqwest::blocking::get("http://example.com/").unwrap().text().unwrap_or(String::from("Alice")); // $ Source=remote_barrier4

// Safe: validated against an array of constants
let allowed_roles = vec!["admin", "user", "guest"];
if allowed_roles.contains(&remote_string4.as_str()) {
let safe_array_check = format!("SELECT * FROM people WHERE role='{remote_string4}'");
let _ = sqlx::query(safe_array_check.as_str()).execute(&pool).await?; // $ sql-sink SPURIOUS: Alert[rust/sql-injection]=remote_barrier4
}

// Safe: validated with slice contains
if ["manager", "employee", "contractor"].contains(&remote_string4.as_str()) {
let safe_slice_check = format!("SELECT * FROM people WHERE role='{remote_string4}'");
let _ = sqlx::query(safe_slice_check.as_str()).execute(&pool).await?; // $ sql-sink SPURIOUS: Alert[rust/sql-injection]=remote_barrier4
}

// Safe: validated with HashSet
use std::collections::HashSet;
let mut allowed_set = HashSet::new();
allowed_set.insert("alice");
allowed_set.insert("bob");
allowed_set.insert("charlie");
if allowed_set.contains(&remote_string4.as_str()) {
let safe_set_check = format!("SELECT * FROM people WHERE name='{remote_string4}'");
let _ = sqlx::query(safe_set_check.as_str()).execute(&pool).await?; // $ sql-sink SPURIOUS: Alert[rust/sql-injection]=remote_barrier4
}
}

// --- Negative test cases: Incorrect sanitization ---

if enable_remote {
let remote_string5 = reqwest::blocking::get("http://example.com/").unwrap().text().unwrap_or(String::from("Alice")); // $ Source=remote_barrier5

// Unsafe: comparison with non-constant value (another variable)
let other_string = String::from("admin");
if remote_string5 == other_string {
let unsafe_non_const = format!("SELECT * FROM people WHERE role='{remote_string5}'");
let _ = sqlx::query(unsafe_non_const.as_str()).execute(&pool).await?; // $ sql-sink Alert[rust/sql-injection]=remote_barrier5
}

// Unsafe: comparison with OR where one side is not validated
if remote_string5 == "admin" || remote_string5.len() > 5 {
let unsafe_mixed_or = format!("SELECT * FROM people WHERE role='{remote_string5}'");
let _ = sqlx::query(unsafe_mixed_or.as_str()).execute(&pool).await?; // $ sql-sink Alert[rust/sql-injection]=remote_barrier5
}

// Unsafe: array contains with non-constant elements
let dynamic_vec = vec![other_string.as_str(), "user"];
if dynamic_vec.contains(&remote_string5.as_str()) {
let unsafe_dynamic_array = format!("SELECT * FROM people WHERE role='{remote_string5}'");
let _ = sqlx::query(unsafe_dynamic_array.as_str()).execute(&pool).await?; // $ sql-sink Alert[rust/sql-injection]=remote_barrier5
}

// Unsafe: checking length or other properties (not value equality)
if remote_string5.len() == 5 {
let unsafe_length_check = format!("SELECT * FROM people WHERE name='{remote_string5}'");
let _ = sqlx::query(unsafe_length_check.as_str()).execute(&pool).await?; // $ sql-sink Alert[rust/sql-injection]=remote_barrier5
}
}

Ok(())
}

fn main() {
println!("--- CWE-089 barriers.rs test ---");
let enable_remote = std::env::args().nth(1) == Some(String::from("ENABLE_REMOTE"));

match futures::executor::block_on(test_barriers(enable_remote)) {
Ok(_) => println!(" successful!"),
Err(e) => println!(" error: {}", e),
}
}