Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add begin_with to start transaction with custom SQL #3322

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions sqlx-core/src/acquire.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::database::Database;
use crate::error::Error;
use crate::pool::{MaybePoolConnection, Pool, PoolConnection};
use std::borrow::Cow;

use crate::transaction::Transaction;
use futures_core::future::BoxFuture;
Expand Down Expand Up @@ -78,6 +79,10 @@ pub trait Acquire<'c> {
fn acquire(self) -> BoxFuture<'c, Result<Self::Connection, Error>>;

fn begin(self) -> BoxFuture<'c, Result<Transaction<'c, Self::Database>, Error>>;

fn begin_with<S>(self, sql: S) -> BoxFuture<'c, Result<Transaction<'c, Self::Database>, Error>>
where
S: Into<Cow<'static, str>> + Send + 'c;
}

impl<'a, DB: Database> Acquire<'a> for &'_ Pool<DB> {
Expand All @@ -96,6 +101,17 @@ impl<'a, DB: Database> Acquire<'a> for &'_ Pool<DB> {
Transaction::begin(MaybePoolConnection::PoolConnection(conn.await?)).await
})
}

fn begin_with<S>(self, sql: S) -> BoxFuture<'a, Result<Transaction<'a, Self::Database>, Error>>
where
S: Into<Cow<'static, str>> + Send + 'a,
{
let conn = self.acquire();

Box::pin(async move {
Transaction::begin_with(MaybePoolConnection::PoolConnection(conn.await?), sql).await
})
}
}

#[macro_export]
Expand Down Expand Up @@ -123,6 +139,20 @@ macro_rules! impl_acquire {
> {
$crate::transaction::Transaction::begin(self)
}

#[inline]
fn begin_with<S>(
self,
stmt: S,
) -> futures_core::future::BoxFuture<
'c,
Result<$crate::transaction::Transaction<'c, $DB>, $crate::error::Error>,
>
where
S: Into<std::borrow::Cow<'static, str>> + Send + 'c,
{
$crate::transaction::Transaction::begin_with(self, stmt)
}
}
};
}
6 changes: 6 additions & 0 deletions sqlx-core/src/any/connection/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::describe::Describe;
use either::Either;
use futures_core::future::BoxFuture;
use futures_core::stream::BoxStream;
use std::borrow::Cow;
use std::fmt::Debug;

pub trait AnyConnectionBackend: std::any::Any + Debug + Send + 'static {
Expand Down Expand Up @@ -30,6 +31,11 @@ pub trait AnyConnectionBackend: std::any::Any + Debug + Send + 'static {
/// Returns a [`Transaction`] for controlling and tracking the new transaction.
fn begin(&mut self) -> BoxFuture<'_, crate::Result<()>>;

/// Begin a new transaction or establish a savepoint within the active transaction.
///
/// Returns a [`Transaction`] for controlling and tracking the new transaction.
fn begin_with(&mut self, sql: Cow<'static, str>) -> BoxFuture<'_, crate::Result<()>>;

fn commit(&mut self) -> BoxFuture<'_, crate::Result<()>>;

fn rollback(&mut self) -> BoxFuture<'_, crate::Result<()>>;
Expand Down
12 changes: 12 additions & 0 deletions sqlx-core/src/any/transaction.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use futures_util::future::BoxFuture;
use std::borrow::Cow;

use crate::any::{Any, AnyConnection};
use crate::database::Database;
use crate::error::Error;
use crate::transaction::TransactionManager;

Expand All @@ -13,6 +15,16 @@ impl TransactionManager for AnyTransactionManager {
conn.backend.begin()
}

fn begin_with<'a, S>(
conn: &'a mut <Self::Database as Database>::Connection,
sql: S,
) -> BoxFuture<'a, Result<(), Error>>
where
S: Into<Cow<'static, str>> + Send + 'a,
{
conn.backend.begin_with(sql.into())
}

fn commit(conn: &mut AnyConnection) -> BoxFuture<'_, Result<(), Error>> {
conn.backend.commit()
}
Expand Down
10 changes: 10 additions & 0 deletions sqlx-core/src/pool/connection.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::borrow::Cow;
use std::fmt::{self, Debug, Formatter};
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
Expand All @@ -11,6 +12,8 @@ use crate::error::Error;

use super::inner::{is_beyond_max_lifetime, DecrementSizeGuard, PoolInner};
use crate::pool::options::PoolConnectionMetadata;
use crate::transaction::Transaction;
use futures_core::future::BoxFuture;
use std::future::Future;

/// A connection managed by a [`Pool`][crate::pool::Pool].
Expand Down Expand Up @@ -159,6 +162,13 @@ impl<'c, DB: Database> crate::acquire::Acquire<'c> for &'c mut PoolConnection<DB
{
crate::transaction::Transaction::begin(&mut **self)
}

fn begin_with<S>(self, sql: S) -> BoxFuture<'c, Result<Transaction<'c, Self::Database>, Error>>
where
S: Into<Cow<'static, str>> + Send + 'c,
{
crate::transaction::Transaction::begin_with(&mut **self, sql)
}
}

/// Returns the connection to the [`Pool`][crate::pool::Pool] it was checked-out from.
Expand Down
34 changes: 34 additions & 0 deletions sqlx-core/src/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@ pub trait TransactionManager {
conn: &mut <Self::Database as Database>::Connection,
) -> BoxFuture<'_, Result<(), Error>>;

fn begin_with<'a, S>(
conn: &'a mut <Self::Database as Database>::Connection,
sql: S,
) -> BoxFuture<'a, Result<(), Error>>
where
S: Into<Cow<'static, str>> + Send + 'a;

/// Commit the active transaction or release the most recent savepoint.
fn commit(
conn: &mut <Self::Database as Database>::Connection,
Expand Down Expand Up @@ -78,6 +85,26 @@ where
})
}

#[doc(hidden)]
pub fn begin_with<S>(
conn: impl Into<MaybePoolConnection<'c, DB>>,
sql: S,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can just be Cow<'static, str> here since it's not a public API. That will save on codgen.

) -> BoxFuture<'c, Result<Self, Error>>
where
S: Into<Cow<'static, str>> + Send + 'c,
{
let mut conn = conn.into();

Box::pin(async move {
DB::TransactionManager::begin_with(&mut conn, sql).await?;

Ok(Self {
connection: conn,
open: true,
})
})
}

/// Commits this transaction or savepoint.
pub async fn commit(mut self) -> Result<(), Error> {
DB::TransactionManager::commit(&mut self.connection).await?;
Expand Down Expand Up @@ -221,6 +248,13 @@ impl<'c, 't, DB: Database> crate::acquire::Acquire<'t> for &'t mut Transaction<'
fn begin(self) -> BoxFuture<'t, Result<Transaction<'t, DB>, Error>> {
Transaction::begin(&mut **self)
}

fn begin_with<S>(self, sql: S) -> BoxFuture<'t, Result<Transaction<'t, Self::Database>, Error>>
where
S: Into<Cow<'static, str>> + Send + 't,
{
Transaction::begin_with(&mut **self, sql)
}
}

impl<'c, DB> Drop for Transaction<'c, DB>
Expand Down
5 changes: 5 additions & 0 deletions sqlx-mysql/src/any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use sqlx_core::database::Database;
use sqlx_core::describe::Describe;
use sqlx_core::executor::Executor;
use sqlx_core::transaction::TransactionManager;
use std::borrow::Cow;

sqlx_core::declare_driver_with_optional_migrate!(DRIVER = MySql);

Expand All @@ -40,6 +41,10 @@ impl AnyConnectionBackend for MySqlConnection {
MySqlTransactionManager::begin(self)
}

fn begin_with(&mut self, sql: Cow<'static, str>) -> BoxFuture<'_, sqlx_core::Result<()>> {
MySqlTransactionManager::begin_with(self, sql)
}

fn commit(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> {
MySqlTransactionManager::commit(self)
}
Expand Down
22 changes: 20 additions & 2 deletions sqlx-mysql/src/transaction.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use futures_core::future::BoxFuture;

use crate::connection::Waiting;
use crate::error::Error;
use crate::executor::Executor;
use crate::protocol::text::Query;
use crate::{MySql, MySqlConnection};
use futures_core::future::BoxFuture;
use sqlx_core::database::Database;
use std::borrow::Cow;

pub(crate) use sqlx_core::transaction::*;

Expand All @@ -25,6 +26,23 @@ impl TransactionManager for MySqlTransactionManager {
})
}

fn begin_with<'a, S>(
conn: &'a mut <Self::Database as Database>::Connection,
sql: S,
) -> BoxFuture<'a, Result<(), Error>>
where
S: Into<Cow<'static, str>> + Send + 'a,
{
Box::pin(async move {
let depth = conn.transaction_depth;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here re. checking transaction_depth.


conn.execute(&*sql.into()).await?;
conn.transaction_depth = depth + 1;

Ok(())
})
}

fn commit(conn: &mut MySqlConnection) -> BoxFuture<'_, Result<(), Error>> {
Box::pin(async move {
let depth = conn.transaction_depth;
Expand Down
5 changes: 5 additions & 0 deletions sqlx-postgres/src/any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::{
use futures_core::future::BoxFuture;
use futures_core::stream::BoxStream;
use futures_util::{StreamExt, TryFutureExt, TryStreamExt};
use std::borrow::Cow;

pub use sqlx_core::any::*;

Expand Down Expand Up @@ -39,6 +40,10 @@ impl AnyConnectionBackend for PgConnection {
PgTransactionManager::begin(self)
}

fn begin_with(&mut self, sql: Cow<'static, str>) -> BoxFuture<'_, sqlx_core::Result<()>> {
PgTransactionManager::begin_with(self, sql)
}

fn commit(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> {
PgTransactionManager::commit(self)
}
Expand Down
23 changes: 21 additions & 2 deletions sqlx-postgres/src/transaction.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use futures_core::future::BoxFuture;

use crate::error::Error;
use crate::executor::Executor;
use futures_core::future::BoxFuture;
use sqlx_core::database::Database;
use std::borrow::Cow;

use crate::{PgConnection, Postgres};

Expand All @@ -26,6 +27,24 @@ impl TransactionManager for PgTransactionManager {
})
}

fn begin_with<'a, S>(
conn: &'a mut <Self::Database as Database>::Connection,
sql: S,
) -> BoxFuture<'a, Result<(), Error>>
where
S: Into<Cow<'static, str>> + Send + 'a,
{
Box::pin(async move {
let rollback = Rollback::new(conn);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should check and return an error if conn.transaction_depth > 0 because it doesn't make sense to allow this to create and release manual savepoints when a transaction is already in progress; that's just going to break things.

rollback.conn.queue_simple_query(&sql.into());
rollback.conn.transaction_depth += 1;
rollback.conn.wait_until_ready().await?;
rollback.defuse();

Ok(())
})
}

fn commit(conn: &mut PgConnection) -> BoxFuture<'_, Result<(), Error>> {
Box::pin(async move {
if conn.transaction_depth > 0 {
Expand Down
5 changes: 5 additions & 0 deletions sqlx-sqlite/src/any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::{
use futures_core::future::BoxFuture;
use futures_core::stream::BoxStream;
use futures_util::{StreamExt, TryFutureExt, TryStreamExt};
use std::borrow::Cow;

use sqlx_core::any::{
Any, AnyArguments, AnyColumn, AnyConnectOptions, AnyConnectionBackend, AnyQueryResult, AnyRow,
Expand Down Expand Up @@ -41,6 +42,10 @@ impl AnyConnectionBackend for SqliteConnection {
SqliteTransactionManager::begin(self)
}

fn begin_with(&mut self, sql: Cow<'static, str>) -> BoxFuture<'_, sqlx_core::Result<()>> {
todo!()
}

fn commit(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> {
SqliteTransactionManager::commit(self)
}
Expand Down
15 changes: 13 additions & 2 deletions sqlx-sqlite/src/transaction.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use futures_core::future::BoxFuture;

use crate::{Sqlite, SqliteConnection};
use futures_core::future::BoxFuture;
use sqlx_core::database::Database;
use sqlx_core::error::Error;
use sqlx_core::transaction::TransactionManager;
use std::borrow::Cow;

/// Implementation of [`TransactionManager`] for SQLite.
pub struct SqliteTransactionManager;
Expand All @@ -14,6 +15,16 @@ impl TransactionManager for SqliteTransactionManager {
Box::pin(conn.worker.begin())
}

fn begin_with<'a, S>(
_conn: &'a mut <Self::Database as Database>::Connection,
sql: S,
) -> BoxFuture<'a, Result<(), Error>>
where
S: Into<Cow<'static, str>> + Send + 'a,
{
unimplemented!()
}

fn commit(conn: &mut SqliteConnection) -> BoxFuture<'_, Result<(), Error>> {
Box::pin(conn.worker.commit())
}
Expand Down