Skip to content

Commit

Permalink
fix: make begin,commit,rollback cancel-safe in sqlite (launchbadge#2054)
Browse files Browse the repository at this point in the history
  • Loading branch information
madadam committed Aug 17, 2022
1 parent 26f60d9 commit acc3801
Showing 1 changed file with 155 additions and 10 deletions.
165 changes: 155 additions & 10 deletions sqlx-core/src/sqlite/connection/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ pub(crate) struct ConnectionWorker {
pub(crate) handle_raw: ConnectionHandleRaw,
/// Mutex for locking access to the database.
pub(crate) shared: Arc<WorkerSharedState>,

// Mirror of `shared.conn.transaction_depth` to help provide cancel-safety:
//
// - If `transaction_depth == shared.conn.transaction_depth` then no cancellation occurred
// - If `transaction_depth == shared.conn.transaction_depth - 1` then a `begin()` was cancelled
// - If `transaction_depth == shared.conn.transaction_depth + 1` then a `commit()` or
// `rollback()` was cancelled
// - No other cases are possible (would indicate a logic bug)
transaction_depth: usize,
}

pub(crate) struct WorkerSharedState {
Expand All @@ -52,15 +61,19 @@ enum Command {
query: Box<str>,
arguments: Option<SqliteArguments<'static>>,
persistent: bool,
transaction_depth: usize,
tx: flume::Sender<Result<Either<SqliteQueryResult, SqliteRow>, Error>>,
},
Begin {
transaction_depth: usize,
tx: oneshot::Sender<Result<(), Error>>,
},
Commit {
transaction_depth: usize,
tx: oneshot::Sender<Result<(), Error>>,
},
Rollback {
transaction_depth: usize,
tx: Option<oneshot::Sender<Result<(), Error>>>,
},
CreateCollation {
Expand Down Expand Up @@ -110,6 +123,7 @@ impl ConnectionWorker {
command_tx,
handle_raw: conn.handle.to_raw(),
shared: Arc::clone(&shared),
transaction_depth: 0,
}))
.is_err()
{
Expand All @@ -135,8 +149,15 @@ impl ConnectionWorker {
query,
arguments,
persistent,
transaction_depth,
tx,
} => {
if let Err(error) = handle_cancelled_begin(&mut conn, transaction_depth)
{
tx.send(Err(error)).ok();
continue;
}

let iter = match execute::iter(&mut conn, &query, arguments, persistent)
{
Ok(iter) => iter,
Expand All @@ -154,7 +175,16 @@ impl ConnectionWorker {

update_cached_statements_size(&conn, &shared.cached_statements_size);
}
Command::Begin { tx } => {
Command::Begin {
transaction_depth,
tx,
} => {
if let Err(error) = handle_cancelled_begin(&mut conn, transaction_depth)
{
tx.send(Err(error)).ok();
continue;
}

let depth = conn.transaction_depth;
let res =
conn.handle
Expand All @@ -165,9 +195,17 @@ impl ConnectionWorker {

tx.send(res).ok();
}
Command::Commit { tx } => {
let depth = conn.transaction_depth;
Command::Commit {
transaction_depth,
tx,
} => {
if let Err(error) = handle_cancelled_begin(&mut conn, transaction_depth)
{
tx.send(Err(error)).ok();
continue;
}

let depth = conn.transaction_depth;
let res = if depth > 0 {
conn.handle
.exec(commit_ansi_transaction_sql(depth))
Expand All @@ -180,9 +218,26 @@ impl ConnectionWorker {

tx.send(res).ok();
}
Command::Rollback { tx } => {
let depth = conn.transaction_depth;
Command::Rollback {
transaction_depth,
tx,
} => {
match handle_cancelled_begin_or_commit_or_rollback(
&mut conn,
transaction_depth,
) {
Ok(true) => (),
Ok(false) => continue,
Err(error) => {
if let Some(tx) = tx {
tx.send(Err(error)).ok();
}

continue;
}
}

let depth = conn.transaction_depth;
let res = if depth > 0 {
conn.handle
.exec(rollback_ansi_transaction_sql(depth))
Expand Down Expand Up @@ -259,6 +314,7 @@ impl ConnectionWorker {
query: query.into(),
arguments: args.map(SqliteArguments::into_static),
persistent,
transaction_depth: self.transaction_depth,
tx,
})
.await
Expand All @@ -268,21 +324,55 @@ impl ConnectionWorker {
}

pub(crate) async fn begin(&mut self) -> Result<(), Error> {
self.oneshot_cmd(|tx| Command::Begin { tx }).await?
let transaction_depth = self.transaction_depth;

self.oneshot_cmd(|tx| Command::Begin {
transaction_depth,
tx,
})
.await??;

self.transaction_depth += 1;

Ok(())
}

pub(crate) async fn commit(&mut self) -> Result<(), Error> {
self.oneshot_cmd(|tx| Command::Commit { tx }).await?
let transaction_depth = self.transaction_depth;

self.oneshot_cmd(|tx| Command::Commit {
transaction_depth,
tx,
})
.await??;

self.transaction_depth -= 1;

Ok(())
}

pub(crate) async fn rollback(&mut self) -> Result<(), Error> {
self.oneshot_cmd(|tx| Command::Rollback { tx: Some(tx) })
.await?
let transaction_depth = self.transaction_depth;

self.oneshot_cmd(|tx| Command::Rollback {
transaction_depth,
tx: Some(tx),
})
.await??;

self.transaction_depth -= 1;

Ok(())
}

pub(crate) fn start_rollback(&mut self) -> Result<(), Error> {
self.transaction_depth -= 1;

self.command_tx
.send(Command::Rollback { tx: None })
.send(Command::Rollback {
transaction_depth: self.transaction_depth,
tx: None,
})
.map_err(|_| Error::WorkerCrashed)
}

Expand Down Expand Up @@ -387,3 +477,58 @@ fn prepare(conn: &mut ConnectionState, query: &str) -> Result<SqliteStatement<'s
fn update_cached_statements_size(conn: &ConnectionState, size: &AtomicUsize) {
size.store(conn.statements.len(), Ordering::Release);
}

// If a `begin()` is cancelled before completion it might happen that the `Begin` command is still
// sent to the worker thread but no `Transaction` is created and so there is no way to commit it or
// roll it back. This function detects such case and handles it by automatically rolling the
// transaction back.
//
// Use only when handling an `Execute`, `Begin` or `Commit` command.
fn handle_cancelled_begin(
conn: &mut ConnectionState,
expected_transaction_depth: usize,
) -> Result<(), Error> {
if expected_transaction_depth != conn.transaction_depth {
if expected_transaction_depth == conn.transaction_depth - 1 {
let depth = conn.transaction_depth;
conn.handle.exec(rollback_ansi_transaction_sql(depth))?;
conn.transaction_depth -= 1;
} else {
// This would indicate cancelled `commit` or `rollback`, but that can only happen when
// handling a `Rollback` command because `commit()` / `rollback()` take the
// transaction by value and so when they are cancelled the transaction is immediately
// dropped which sends a `Rollback`.
unreachable!()
}
}

Ok(())
}

// Same as `handle_cancelled_begin` but additionally handles cancelled `commit()` and `rollback()`
// as well. If `commit()` / `rollback()` is cancelled, it might happen that the corresponding
// `Commit` / `Rollback` command is still sent to the worker thread but the transaction's `open`
// flag is not set to `false` which causes another `Rollback` to be sent when the transaction
// is dropped. This function detects that case and indicates to ignore the superfluous `Rollback`.
//
// Use only when handling a `Rollback` command.
fn handle_cancelled_begin_or_commit_or_rollback(
conn: &mut ConnectionState,
expected_transaction_depth: usize,
) -> Result<bool, Error> {
if expected_transaction_depth != conn.transaction_depth {
if expected_transaction_depth == conn.transaction_depth - 1 {
let depth = conn.transaction_depth;
conn.handle.exec(rollback_ansi_transaction_sql(depth))?;
conn.transaction_depth -= 1;

Ok(true)
} else if expected_transaction_depth == conn.transaction_depth + 1 {
Ok(false)
} else {
unreachable!()
}
} else {
Ok(true)
}
}

0 comments on commit acc3801

Please sign in to comment.