diff --git a/appveyor.yml b/appveyor.yml index 78518e21..5d02b6f2 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -23,13 +23,13 @@ environment: PATH: C:\msys64\mingw64\bin\;c:\rust\bin;%PATH% - TARGET: beta-i686-pc-windows-gnu PATH: C:\msys64\mingw32\bin\;c:\rust\bin;%PATH% - - TARGET: 1.39.0-x86_64-pc-windows-msvc + - TARGET: 1.42.0-x86_64-pc-windows-msvc PATH: C:\msys64\mingw64\bin\;c:\rust\bin;%PATH% - - TARGET: 1.39.0-i686-pc-windows-msvc + - TARGET: 1.42.0-i686-pc-windows-msvc PATH: C:\msys64\mingw32\bin\;c:\rust\bin;%PATH% - - TARGET: 1.39.0-x86_64-pc-windows-gnu + - TARGET: 1.42.0-x86_64-pc-windows-gnu PATH: C:\msys64\mingw64\bin\;c:\rust\bin;%PATH% - - TARGET: 1.39.0-i686-pc-windows-gnu + - TARGET: 1.42.0-i686-pc-windows-gnu PATH: C:\msys64\mingw32\bin\;c:\rust\bin;%PATH% services: mysql install: diff --git a/src/conn/mod.rs b/src/conn/mod.rs index e939a16d..013a2098 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -12,8 +12,9 @@ use mysql_common::{ constants::DEFAULT_MAX_ALLOWED_PACKET, crypto, packets::{ - parse_auth_switch_request, parse_handshake_packet, AuthPlugin, AuthSwitchRequest, - HandshakeResponse, OkPacket, SslRequest, + parse_auth_switch_request, parse_err_packet, parse_handshake_packet, parse_ok_packet, + AuthPlugin, AuthSwitchRequest, ErrPacket, HandshakeResponse, OkPacket, OkPacketKind, + SslRequest, }, }; @@ -24,22 +25,22 @@ use std::{ mem, pin::Pin, str::FromStr, - sync::Arc, time::{Duration, Instant}, }; use crate::{ conn::{pool::Pool, stmt_cache::StmtCache}, connection_like::ConnectionLike, - consts::{self, CapabilityFlags}, + consts::{CapabilityFlags, Command, StatusFlags}, error::*, io::Stream, - local_infile_handler::LocalInfileHandler, opts::Opts, queryable::{ - query_result::QueryResult, transaction::TxStatus, BinaryProtocol, Queryable, TextProtocol, + query_result::{QueryResult, ResultSetMeta}, + transaction::TxStatus, + BinaryProtocol, Queryable, TextProtocol, }, - Column, OptsBuilder, + OptsBuilder, }; pub mod pool; @@ -70,25 +71,18 @@ fn disconnect(mut conn: Conn) { } } -#[derive(Debug)] -pub enum PendingResult { - Text(Arc>), - Binary(Arc>), - Empty, -} - /// Mysql connection struct ConnInner { stream: Option, id: u32, version: (u16, u16, u16), - max_allowed_packet: usize, socket: Option, - capabilities: consts::CapabilityFlags, - status: consts::StatusFlags, + capabilities: CapabilityFlags, + status: StatusFlags, last_ok_packet: Option>, + last_err_packet: Option>, pool: Option, - has_result: Option, + has_result: Option, tx_status: TxStatus, opts: Opts, last_io: Instant, @@ -120,10 +114,10 @@ impl ConnInner { fn empty(opts: Opts) -> ConnInner { ConnInner { capabilities: opts.get_capabilities(), - status: consts::StatusFlags::empty(), + status: StatusFlags::empty(), last_ok_packet: None, + last_err_packet: None, stream: None, - max_allowed_packet: DEFAULT_MAX_ALLOWED_PACKET, version: (0, 0, 0), id: 0, has_result: None, @@ -140,6 +134,17 @@ impl ConnInner { disconnected: false, } } + + /// Returns mutable reference to a connection stream. + /// + /// # Panic + /// + /// Will panic if stream is already taken. + fn stream_mut(&mut self) -> &mut Stream { + self.stream + .as_mut() + .expect("call to stream_mut on invalid connection") + } } #[derive(Debug)] @@ -149,7 +154,7 @@ pub struct Conn { impl Conn { /// Returns connection identifier. - pub fn connection_id(&self) -> u32 { + pub fn id(&self) -> u32 { self.inner.id } @@ -157,13 +162,116 @@ impl Conn { /// `AUTO_INCREMENT` attribute. Returns `None` if there was no previous query on the connection /// or if the query did not update an AUTO_INCREMENT value. pub fn last_insert_id(&self) -> Option { - self.get_last_insert_id() + self.inner + .last_ok_packet + .as_ref() + .and_then(|ok| ok.last_insert_id()) } /// Returns the number of rows affected by the last `INSERT`, `UPDATE`, `REPLACE` or `DELETE` /// query. pub fn affected_rows(&self) -> u64 { - self.get_affected_rows() + self.inner + .last_ok_packet + .as_ref() + .map(|ok| ok.affected_rows()) + .unwrap_or_default() + } + + /// Text information, as reported by the server in the last OK packet, or an empty string. + pub fn info(&self) -> Cow<'_, str> { + self.inner + .last_ok_packet + .as_ref() + .and_then(|ok| ok.info_str()) + .unwrap_or_else(|| "".into()) + } + + /// Number of warnings, as reported by the server in the last OK packet, or `0`. + pub fn get_warnings(&self) -> u16 { + self.inner + .last_ok_packet + .as_ref() + .map(|ok| ok.warnings()) + .unwrap_or_default() + } + + pub(crate) fn stream_mut(&mut self) -> &mut Stream { + self.inner.stream_mut() + } + + pub(crate) fn capabilities(&self) -> CapabilityFlags { + self.inner.capabilities + } + + /// Will update last IO time for this connection. + pub(crate) fn touch(&mut self) { + self.inner.last_io = Instant::now(); + } + + /// Will set packet sequence id to `0`. + pub(crate) fn reset_seq_id(&mut self) { + if let Some(stream) = self.inner.stream.as_mut() { + stream.reset_seq_id(); + } + } + + /// Will syncronize sequence ids between compressed and uncompressed codecs. + pub(crate) fn sync_seq_id(&mut self) { + if let Some(stream) = self.inner.stream.as_mut() { + stream.sync_seq_id(); + } + } + + /// Handles OK packet. + pub(crate) fn handle_ok(&mut self, ok_packet: OkPacket<'static>) { + self.inner.status = ok_packet.status_flags(); + self.inner.last_err_packet = None; + self.inner.last_ok_packet = Some(ok_packet); + } + + /// Handles ERR packet. + pub(crate) fn handle_err(&mut self, err_packet: ErrPacket<'static>) { + self.inner.status = StatusFlags::empty(); + self.inner.last_ok_packet = None; + self.inner.last_err_packet = Some(err_packet); + } + + /// Returns the current transaction status. + pub(crate) fn get_tx_status(&self) -> TxStatus { + self.inner.tx_status + } + + /// Sets the given transaction status for this connection. + pub(crate) fn set_tx_status(&mut self, tx_status: TxStatus) { + self.inner.tx_status = tx_status; + } + + /// Returns pending result metadata, if any. + /// + /// If `Some(_)`, then result is not yet consumed. + pub(crate) fn get_pending_result(&self) -> Option<&ResultSetMeta> { + self.inner.has_result.as_ref() + } + + /// Sets the given pening result metadata for this connection. + pub(crate) fn set_pending_result(&mut self, meta: Option) { + self.inner.has_result = meta; + } + + /// Returns current status flags. + pub(crate) fn status(&self) -> StatusFlags { + self.inner.status + } + + /// Returns server version. + pub fn server_version(&self) -> (u16, u16, u16) { + self.inner.version + } + + /// Returns connection options. + pub fn opts(&self) -> &Opts { + &self.inner.opts } fn take_stream(&mut self) -> Stream { @@ -172,9 +280,8 @@ impl Conn { /// Disconnects this connection from server. pub async fn disconnect(mut self) -> Result<()> { - self.on_disconnect(); - self.write_command_data(crate::consts::Command::COM_QUIT, &[]) - .await?; + self.inner.disconnected = true; + self.write_command_data(Command::COM_QUIT, &[]).await?; let stream = self.take_stream(); stream.close().await?; Ok(()) @@ -251,12 +358,8 @@ impl Conn { let ssl_request = SslRequest::new(self.inner.capabilities); self.write_packet(ssl_request.as_ref()).await?; let conn = self; - let ssl_opts = conn - .get_opts() - .get_ssl_opts() - .cloned() - .expect("unreachable"); - let domain = conn.get_opts().get_ip_or_hostname().into(); + let ssl_opts = conn.opts().get_ssl_opts().cloned().expect("unreachable"); + let domain = conn.opts().get_ip_or_hostname().into(); conn.stream_mut().make_secure(domain, ssl_opts).await?; Ok(()) } else { @@ -276,7 +379,7 @@ impl Conn { self.inner.opts.get_user(), self.inner.opts.get_db_name(), &self.inner.auth_plugin, - self.get_capabilities(), + self.capabilities(), &Default::default(), // TODO: Add support ); @@ -327,7 +430,7 @@ impl Conn { fn switch_to_compression(&mut self) -> Result<()> { if self - .get_capabilities() + .capabilities() .contains(CapabilityFlags::CLIENT_COMPRESS) { if let Some(compression) = self.inner.opts.get_compression() { @@ -405,6 +508,74 @@ impl Conn { } } + fn handle_packet(&mut self, packet: &[u8]) -> Result<()> { + let kind = if self.get_pending_result().is_some() { + OkPacketKind::ResultSetTerminator + } else { + OkPacketKind::Other + }; + + if let Ok(ok_packet) = parse_ok_packet(&*packet, self.capabilities(), kind) { + self.handle_ok(ok_packet.into_owned()); + } else if let Ok(err_packet) = parse_err_packet(&*packet, self.capabilities()) { + self.handle_err(err_packet.clone().into_owned()); + return Err(err_packet.into()).into(); + } + + Ok(()) + } + + pub(crate) async fn read_packet(&mut self) -> Result> { + let packet = crate::io::ReadPacket::new(self).await.map_err(|io_err| { + self.inner.stream.take(); + self.inner.disconnected = true; + Error::from(io_err) + })?; + self.handle_packet(&*packet)?; + Ok(packet) + } + + /// Returns future that reads packets from a server. + pub(crate) async fn read_packets(&mut self, n: usize) -> Result>> { + let mut packets = Vec::with_capacity(n); + for _ in 0..n { + packets.push(self.read_packet().await?); + } + Ok(packets) + } + + pub(crate) async fn write_packet(&mut self, data: T) -> Result<()> + where + T: Into>, + { + crate::io::WritePacket::new(self, data.into()) + .await + .map_err(|io_err| { + self.inner.stream.take(); + self.inner.disconnected = true; + From::from(io_err) + }) + } + + /// Returns future that sends full command body to a server. + pub(crate) async fn write_command_raw(&mut self, body: Vec) -> Result<()> { + debug_assert!(body.len() > 0); + self.conn_mut().reset_seq_id(); + self.write_packet(body).await + } + + /// Returns future that writes command to a server. + pub(crate) async fn write_command_data(&mut self, cmd: Command, cmd_data: T) -> Result<()> + where + T: AsRef<[u8]>, + { + let cmd_data = cmd_data.as_ref(); + let mut body = Vec::with_capacity(1 + cmd_data.len()); + body.push(cmd as u8); + body.extend_from_slice(cmd_data); + self.write_command_raw(body).await + } + async fn drop_packet(&mut self) -> Result<()> { self.read_packet().await?; Ok(()) @@ -528,7 +699,7 @@ impl Conn { let pool = self.inner.pool.clone(); if self.inner.version > (5, 7, 2) { - self.write_command_data(consts::Command::COM_RESET_CONNECTION, &[]) + self.write_command_data(Command::COM_RESET_CONNECTION, &[]) .await?; self.read_packet().await?; } else { @@ -552,20 +723,20 @@ impl Conn { pub(crate) async fn drop_result(&mut self) -> Result<()> { match self.inner.has_result.take() { - Some(PendingResult::Text(columns)) => { - QueryResult::<'_, _, TextProtocol>::new(self, Some(columns)) + Some(meta @ ResultSetMeta::Text(_)) => { + QueryResult::<'_, _, TextProtocol>::new(self, meta) .drop_result() .await?; Ok(()) } - Some(PendingResult::Binary(columns)) => { - QueryResult::<'_, _, BinaryProtocol>::new(self, Some(columns)) + Some(meta @ ResultSetMeta::Binary(_)) => { + QueryResult::<'_, _, BinaryProtocol>::new(self, meta) .drop_result() .await?; Ok(()) } - Some(PendingResult::Empty) => { - QueryResult::<'_, _, TextProtocol>::new(self, None) + Some(meta @ ResultSetMeta::Empty) => { + QueryResult::<'_, _, TextProtocol>::new(self, meta) .drop_result() .await?; Ok(()) @@ -589,123 +760,12 @@ impl Conn { } impl ConnectionLike for Conn { - fn conn_mut(&mut self) -> &mut crate::Conn { + fn conn_ref(&self) -> &crate::Conn { self } - fn connection_id(&self) -> u32 { - self.connection_id() - } - - fn stream_mut(&mut self) -> &mut Stream { - self.inner.stream.as_mut().expect("Logic error: stream") - } - - fn get_affected_rows(&self) -> u64 { - self.inner - .last_ok_packet - .as_ref() - .map(|ok| ok.affected_rows()) - .unwrap_or_default() - } - - fn get_capabilities(&self) -> consts::CapabilityFlags { - self.inner.capabilities - } - - fn get_tx_status(&self) -> TxStatus { - self.inner.tx_status - } - - fn get_last_insert_id(&self) -> Option { - self.inner - .last_ok_packet - .as_ref() - .and_then(|ok| ok.last_insert_id()) - } - - fn get_info(&self) -> Cow<'_, str> { - self.inner - .last_ok_packet - .as_ref() - .and_then(|ok| ok.info_str()) - .unwrap_or_else(|| "".into()) - } - - fn get_warnings(&self) -> u16 { - self.inner - .last_ok_packet - .as_ref() - .map(|ok| ok.warnings()) - .unwrap_or_default() - } - - fn get_local_infile_handler(&self) -> Option> { - self.inner.opts.get_local_infile_handler() - } - - fn get_max_allowed_packet(&self) -> usize { - self.inner.max_allowed_packet - } - - fn get_opts(&self) -> &Opts { - &self.inner.opts - } - - fn get_pending_result(&self) -> Option<&PendingResult> { - self.inner.has_result.as_ref() - } - - fn get_server_version(&self) -> (u16, u16, u16) { - self.inner.version - } - - fn get_status(&self) -> consts::StatusFlags { - self.inner.status - } - - fn set_last_ok_packet(&mut self, ok_packet: Option>) { - self.inner.last_ok_packet = ok_packet; - } - - fn set_tx_status(&mut self, tx_status: TxStatus) { - self.inner.tx_status = tx_status; - } - - fn set_pending_result(&mut self, meta: Option) { - self.inner.has_result = meta; - } - - fn set_status(&mut self, status: consts::StatusFlags) { - self.inner.status = status; - } - - fn reset_seq_id(&mut self) { - if let Some(stream) = self.inner.stream.as_mut() { - stream.reset_seq_id(); - } - } - - fn sync_seq_id(&mut self) { - if let Some(stream) = self.inner.stream.as_mut() { - stream.sync_seq_id(); - } - } - - fn touch(&mut self) { - self.inner.last_io = Instant::now(); - } - - fn on_disconnect(&mut self) { - self.inner.disconnected = true; - } - - fn stmt_cache_ref(&self) -> &StmtCache { - &self.inner.stmt_cache - } - - fn stmt_cache_mut(&mut self) -> &mut StmtCache { - &mut self.inner.stmt_cache + fn conn_mut(&mut self) -> &mut crate::Conn { + self } } @@ -756,10 +816,17 @@ mod test { let mut tx = conn.start_transaction(Default::default()).await?; tx.query_drop("INSERT INTO mysql.foo (id) VALUES (42)") .await?; - tx.exec_iter("SELECT * FROM mysql.foo", ()).await?; + tx.exec_iter("SELECT COUNT(*) FROM mysql.foo", ()).await?; drop(tx); conn.ping().await?; + let count: u8 = conn + .query_first("SELECT COUNT(*) FROM mysql.foo") + .await? + .unwrap_or_default(); + + assert_eq!(count, 0); + Ok(()) } @@ -887,8 +954,6 @@ mod test { #[tokio::test] async fn should_hold_stmt_cache_size_bound() -> super::Result<()> { - use crate::connection_like::ConnectionLike; - let mut opts = OptsBuilder::from_opts(get_opts()); opts.stmt_cache_size(3); let mut conn = Conn::new(opts).await?; @@ -1383,99 +1448,58 @@ mod test { #[cfg(feature = "nightly")] mod bench { - use futures_util::try_future::TryFutureExt; - use crate::{conn::Conn, queryable::Queryable, test_misc::get_opts}; #[bench] fn simple_exec(bencher: &mut test::Bencher) { - let runtime = tokio::runtime::Runtime::new().unwrap(); - let mut conn_opt = Some(runtime.block_on(Conn::new(get_opts())).unwrap()); + let mut runtime = tokio::runtime::Runtime::new().unwrap(); + let mut conn = runtime.block_on(Conn::new(get_opts())).unwrap(); bencher.iter(|| { - let conn = conn_opt.take().unwrap(); - conn_opt = Some(runtime.block_on(conn.query_drop("DO 1")).unwrap()); + runtime.block_on(conn.query_drop("DO 1")).unwrap(); }); - runtime - .block_on(conn_opt.take().unwrap().disconnect()) - .unwrap(); - runtime.shutdown_on_idle(); + runtime.block_on(conn.disconnect()).unwrap(); } #[bench] fn select_large_string(bencher: &mut test::Bencher) { - let runtime = tokio::runtime::Runtime::new().unwrap(); - let mut conn_opt = Some(runtime.block_on(Conn::new(get_opts())).unwrap()); + let mut runtime = tokio::runtime::Runtime::new().unwrap(); + let mut conn = runtime.block_on(Conn::new(get_opts())).unwrap(); bencher.iter(|| { - let conn = conn_opt.take().unwrap(); - conn_opt = Some( - runtime - .block_on(conn.query_drop("SELECT REPEAT('A', 10000)")) - .unwrap(), - ); + runtime + .block_on(conn.query_drop("SELECT REPEAT('A', 10000)")) + .unwrap(); }); - runtime - .block_on(conn_opt.take().unwrap().disconnect()) - .unwrap(); - runtime.shutdown_on_idle(); + runtime.block_on(conn.disconnect()).unwrap(); } #[bench] fn prepared_exec(bencher: &mut test::Bencher) { - let runtime = tokio::runtime::Runtime::new().unwrap(); - let mut stmt_opt = Some( - runtime - .block_on(Conn::new(get_opts()).and_then(|conn| conn.prepare("DO 1"))) - .unwrap(), - ); + let mut runtime = tokio::runtime::Runtime::new().unwrap(); + let mut conn = runtime.block_on(Conn::new(get_opts())).unwrap(); + let stmt = runtime.block_on(conn.prep("DO 1")).unwrap(); bencher.iter(|| { - let stmt = stmt_opt.take().unwrap(); - stmt_opt = Some( - runtime - .block_on(stmt.execute(()).and_then(|result| result.drop_result())) - .unwrap(), - ); + runtime.block_on(conn.exec_drop(&stmt, ())).unwrap(); }); - runtime - .block_on( - stmt_opt - .take() - .unwrap() - .close() - .and_then(|conn| conn.disconnect()), - ) - .unwrap(); - runtime.shutdown_on_idle(); + runtime.block_on(conn.close(stmt)).unwrap(); + runtime.block_on(conn.disconnect()).unwrap(); } #[bench] fn prepare_and_exec(bencher: &mut test::Bencher) { - let runtime = tokio::runtime::Runtime::new().unwrap(); - let mut conn_opt = Some(runtime.block_on(Conn::new(get_opts())).unwrap()); + let mut runtime = tokio::runtime::Runtime::new().unwrap(); + let mut conn = runtime.block_on(Conn::new(get_opts())).unwrap(); bencher.iter(|| { - let conn = conn_opt.take().unwrap(); - conn_opt = Some( - runtime - .block_on( - conn.prepare("SELECT ?") - .and_then(|stmt| stmt.execute((0,))) - .and_then(|result| result.drop_result()) - .and_then(|stmt| stmt.close()), - ) - .unwrap(), - ); + runtime.block_on(conn.exec_drop("SELECT ?", (0,))).unwrap(); }); - runtime - .block_on(conn_opt.take().unwrap().disconnect()) - .unwrap(); - runtime.shutdown_on_idle(); + runtime.block_on(conn.disconnect()).unwrap(); } } } diff --git a/src/conn/pool/futures/disconnect_pool.rs b/src/conn/pool/futures/disconnect_pool.rs index 8af5f7c6..9a0d3244 100644 --- a/src/conn/pool/futures/disconnect_pool.rs +++ b/src/conn/pool/futures/disconnect_pool.rs @@ -24,6 +24,8 @@ use std::sync::{atomic, Arc}; /// /// **Note:** This Future won't resolve until all active connections, taken from it, /// are dropped or disonnected. Also all pending and new `GetConn`'s will resolve to error. +#[derive(Debug)] +#[must_use = "futures do nothing unless you `.await` or poll them"] pub struct DisconnectPool { pool_inner: Arc, } diff --git a/src/conn/pool/futures/get_conn.rs b/src/conn/pool/futures/get_conn.rs index 3b9fd771..7f35f659 100644 --- a/src/conn/pool/futures/get_conn.rs +++ b/src/conn/pool/futures/get_conn.rs @@ -19,11 +19,14 @@ use crate::{ BoxFuture, }; +#[derive(Debug)] pub(crate) enum GetConnInner { New, - Done(Option), + Done, // TODO: one day this should be an existential Connecting(BoxFuture<'static, Conn>), + /// This future will check, that idling connection is alive. + Checking(BoxFuture<'static, Conn>), } impl GetConnInner { @@ -34,15 +37,31 @@ impl GetConnInner { } /// This future will take connection from a pool and resolve to [`Conn`]. +#[derive(Debug)] +#[must_use = "futures do nothing unless you `.await` or poll them"] pub struct GetConn { pub(crate) pool: Option, pub(crate) inner: GetConnInner, } -pub fn new(pool: &Pool) -> GetConn { - GetConn { - pool: Some(pool.clone()), - inner: GetConnInner::New, +impl GetConn { + pub(crate) fn new(pool: &Pool) -> GetConn { + GetConn { + pool: Some(pool.clone()), + inner: GetConnInner::New, + } + } + + fn pool_mut(&mut self) -> &mut Pool { + self.pool + .as_mut() + .expect("GetConn::poll polled after returning Async::Ready") + } + + fn pool_take(&mut self) -> Pool { + self.pool + .take() + .expect("GetConn::poll polled after returning Async::Ready") } } @@ -54,46 +73,31 @@ impl Future for GetConn { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { loop { match self.inner { - GetConnInner::New => match ready!(Pin::new( - self.pool - .as_mut() - .expect("GetConn::poll polled after returning Async::Ready") - ) - .poll_new_conn(cx))? - .inner - .take() + GetConnInner::New => match ready!(Pin::new(self.pool_mut()).poll_new_conn(cx))? + .inner + .take() { - GetConnInner::Done(Some(conn)) => { - self.inner = GetConnInner::Done(Some(conn)); - } GetConnInner::Connecting(conn_fut) => { self.inner = GetConnInner::Connecting(conn_fut); } - GetConnInner::Done(None) => unreachable!( + GetConnInner::Checking(conn_fut) => { + self.inner = GetConnInner::Checking(conn_fut); + } + GetConnInner::Done => unreachable!( "Pool::poll_new_conn never gives out already-consumed GetConns" ), GetConnInner::New => { unreachable!("Pool::poll_new_conn never gives out GetConnInner::New") } }, - GetConnInner::Done(ref mut c @ Some(_)) => { - let mut c = c.take().unwrap(); - c.inner.pool = Some( - self.pool - .take() - .expect("GetConn::poll polled after returning Async::Ready"), - ); - return Poll::Ready(Ok(c)); - } - GetConnInner::Done(None) => { + GetConnInner::Done => { unreachable!("GetConn::poll polled after returning Async::Ready"); } GetConnInner::Connecting(ref mut f) => { let result = ready!(Pin::new(f).poll(cx)); - let pool = self - .pool - .take() - .expect("GetConn::poll polled after returning Async::Ready"); + let pool = self.pool_take(); + + self.inner = GetConnInner::Done; return match result { Ok(mut c) => { @@ -106,6 +110,26 @@ impl Future for GetConn { } }; } + GetConnInner::Checking(ref mut f) => { + let result = ready!(Pin::new(f).poll(cx)); + match result { + Ok(mut checked_conn) => { + self.inner = GetConnInner::Done; + + let pool = self.pool_take(); + checked_conn.inner.pool = Some(pool); + return Poll::Ready(Ok(checked_conn)); + } + Err(_) => { + // Idling connection is broken. We'll drop it and try again. + self.inner = GetConnInner::New; + + let pool = self.pool_mut(); + pool.cancel_connection(); + continue; + } + } + } } } } diff --git a/src/conn/pool/futures/mod.rs b/src/conn/pool/futures/mod.rs index 24b36233..00842994 100644 --- a/src/conn/pool/futures/mod.rs +++ b/src/conn/pool/futures/mod.rs @@ -7,10 +7,7 @@ // modified, or distributed except according to those terms. pub(super) use self::get_conn::GetConnInner; -pub use self::{ - disconnect_pool::DisconnectPool, - get_conn::{new as new_get_conn, GetConn}, -}; +pub use self::{disconnect_pool::DisconnectPool, get_conn::GetConn}; mod disconnect_pool; mod get_conn; diff --git a/src/conn/pool/mod.rs b/src/conn/pool/mod.rs index 759f7290..9c3a6e2e 100644 --- a/src/conn/pool/mod.rs +++ b/src/conn/pool/mod.rs @@ -60,6 +60,7 @@ impl From for IdlingConn { /// /// It is held under a single, non-asynchronous lock. /// This is fine as long as we never do expensive work while holding the lock! +#[derive(Debug)] struct Exchange { waiting: VecDeque, available: VecDeque, @@ -86,6 +87,7 @@ impl Exchange { } } +#[derive(Debug)] pub struct Inner { close: atomic::AtomicBool, closed: atomic::AtomicBool, @@ -139,7 +141,7 @@ impl Pool { /// Returns a future that resolves to [`Conn`]. pub fn get_conn(&self) -> GetConn { - new_get_conn(self) + GetConn::new(self) } /// Returns a future that disconnects this pool from the server and resolves to `()`. @@ -233,11 +235,14 @@ impl Pool { exchange.spawn_futures_if_needed(&self.inner); loop { - if let Some(IdlingConn { conn, .. }) = exchange.available.pop_back() { + if let Some(IdlingConn { mut conn, .. }) = exchange.available.pop_back() { if !conn.expired() { return Poll::Ready(Ok(GetConn { pool: Some(self.clone()), - inner: GetConnInner::Done(Some(conn)), + inner: GetConnInner::Checking(BoxFuture(Box::pin(async move { + conn.stream_mut().check().await?; + Ok(conn) + }))), })); } else { self.send_to_recycler(conn); @@ -339,6 +344,60 @@ mod test { Ok(()) } + #[tokio::test] + async fn should_reconnect() -> super::Result<()> { + let mut master = crate::Conn::new(get_opts()).await?; + + async fn test(master: &mut crate::Conn, opts: crate::OptsBuilder) -> super::Result<()> { + const NUM_CONNS: usize = 5; + let pool = Pool::new(opts); + + // create some conns.. + let connections = (0..NUM_CONNS).map(|_| { + crate::BoxFuture(Box::pin(async { + let mut conn = pool.get_conn().await?; + conn.ping().await?; + Ok(conn) + })) + }); + + // collect ids.. + let ids = try_join_all(connections) + .await? + .into_iter() + .map(|conn| (conn.id(),)) + .collect::>(); + + // get_conn should work if connection is available and alive + pool.get_conn().await?; + + // now we'll kill connections.. + master.exec_batch("KILL ?", ids).await?; + + // now check, that they're still in the pool.. + assert_eq!(ex_field!(pool, available).len(), NUM_CONNS); + + // now get new connection.. + let _conn = pool.get_conn().await?; + + // now check, that broken connections are dropped + assert_eq!(ex_field!(pool, available).len(), 0); + + drop(_conn); + pool.disconnect().await + } + + let mut opts = get_opts(); + + println!("Check socket/pipe.."); + test(&mut master, opts.clone()).await?; + opts.prefer_socket(false); + println!("Check tcp.."); + test(&mut master, opts).await?; + + master.disconnect().await + } + #[tokio::test] #[ignore] async fn can_handle_the_pressure() { @@ -567,6 +626,13 @@ mod test { .unwrap(); drop(conns); drop(pool); + + let pool = Pool::new(get_opts()); + let conns = try_join_all((0..10).map(|_| pool.get_conn())) + .await + .unwrap(); + drop(pool); + drop(conns); Ok(()) } @@ -637,19 +703,21 @@ mod test { #[cfg(feature = "nightly")] mod bench { - use futures_util::{future::FutureExt, try_future::TryFutureExt}; + use futures_util::future::{FutureExt, TryFutureExt}; use tokio::runtime::Runtime; use crate::{prelude::Queryable, test_misc::get_opts, Pool, PoolConstraints, PoolOptions}; use std::time::Duration; #[bench] - fn connect(bencher: &mut test::Bencher) { + fn get_conn(bencher: &mut test::Bencher) { let mut runtime = Runtime::new().unwrap(); let pool = Pool::new(get_opts()); bencher.iter(|| { - let fut = pool.get_conn().and_then(|conn| conn.ping()); + let fut = pool + .get_conn() + .and_then(|mut conn| async { conn.ping().await.map(|_| conn) }); runtime.block_on(fut).unwrap(); }); @@ -658,7 +726,7 @@ mod test { #[bench] fn new_conn_on_pool_soft_boundary(bencher: &mut test::Bencher) { - let runtime = Runtime::new().unwrap(); + let mut runtime = Runtime::new().unwrap(); let mut opts = get_opts(); let mut pool_opts = PoolOptions::with_constraints(PoolConstraints::new(0, 1).unwrap()); diff --git a/src/conn/pool/recycler.rs b/src/conn/pool/recycler.rs index 7acdbc59..6c9bc2f2 100644 --- a/src/conn/pool/recycler.rs +++ b/src/conn/pool/recycler.rs @@ -21,6 +21,8 @@ use super::{IdlingConn, Inner}; use crate::{queryable::transaction::TxStatus, BoxFuture, Conn, PoolOptions}; use tokio::sync::mpsc::UnboundedReceiver; +#[derive(Debug)] +#[must_use = "futures do nothing unless you `.await` or poll them"] pub(crate) struct Recycler { inner: Arc, discard: FuturesUnordered>, diff --git a/src/conn/stmt_cache.rs b/src/conn/stmt_cache.rs index 1d2d4031..5acdf94d 100644 --- a/src/conn/stmt_cache.rs +++ b/src/conn/stmt_cache.rs @@ -54,15 +54,6 @@ impl StmtCache { } } - pub fn contains_query(&self, key: &T) -> bool - where - QueryString: Borrow, - T: Hash + Eq, - T: ?Sized, - { - self.query_map.contains_key(key) - } - pub fn by_query(&mut self, query: &T) -> Option<&Entry> where QueryString: Borrow, @@ -113,11 +104,40 @@ impl StmtCache { self.cache.iter() } - pub fn into_iter(mut self) -> impl Iterator { - std::iter::from_fn(move || self.cache.pop_lru()) - } - + #[cfg(test)] pub fn len(&self) -> usize { self.cache.len() } } + +impl super::Conn { + #[cfg(test)] + pub(crate) fn stmt_cache_ref(&self) -> &StmtCache { + &self.inner.stmt_cache + } + + pub(crate) fn stmt_cache_mut(&mut self) -> &mut StmtCache { + &mut self.inner.stmt_cache + } + + /// Caches the given statement. + /// + /// Returns LRU statement on cache capacity overflow. + pub(crate) fn cache_stmt(&mut self, stmt: &Arc) -> Option> { + let query = stmt.raw_query.clone(); + if self.inner.opts.get_stmt_cache_size() > 0 { + self.stmt_cache_mut().put(query, stmt.clone()) + } else { + None + } + } + + /// Returns statement, if cached. + /// + /// `raw_query` is the query with `?` placeholders (not with `:` placeholders). + pub(crate) fn get_cached_stmt(&mut self, raw_query: &str) -> Option> { + self.stmt_cache_mut() + .by_query(raw_query) + .map(|entry| entry.stmt.clone()) + } +} diff --git a/src/connection_like/mod.rs b/src/connection_like/mod.rs index d9a10273..ea2b467e 100644 --- a/src/connection_like/mod.rs +++ b/src/connection_like/mod.rs @@ -6,184 +6,7 @@ // option. All files in the project carrying such notice may not be copied, // modified, or distributed except according to those terms. -use mysql_common::{ - io::ReadMysqlExt, - packets::{column_from_payload, parse_local_infile_packet, Column, OkPacket}, -}; -use tokio::prelude::*; - -use std::{borrow::Cow, sync::Arc}; - -use crate::{ - conn::{stmt_cache::StmtCache, PendingResult}, - connection_like::{ - read_packet::{ReadPacket, ReadPackets}, - write_packet::WritePacket, - }, - consts::{CapabilityFlags, Command, StatusFlags}, - error::*, - io, - local_infile_handler::LocalInfileHandler, - queryable::{query_result::QueryResult, stmt::StmtInner, transaction::TxStatus, Protocol}, - BoxFuture, Opts, -}; - -pub mod read_packet; -pub mod write_packet; - pub trait ConnectionLike: Send + Sized { + fn conn_ref(&self) -> &crate::Conn; fn conn_mut(&mut self) -> &mut crate::Conn; - fn connection_id(&self) -> u32; - fn stream_mut(&mut self) -> &mut io::Stream; - fn stmt_cache_ref(&self) -> &StmtCache; - fn stmt_cache_mut(&mut self) -> &mut StmtCache; - fn get_affected_rows(&self) -> u64; - fn get_capabilities(&self) -> CapabilityFlags; - fn get_tx_status(&self) -> TxStatus; - fn get_last_insert_id(&self) -> Option; - fn get_info(&self) -> Cow<'_, str>; - fn get_warnings(&self) -> u16; - fn get_local_infile_handler(&self) -> Option>; - fn get_max_allowed_packet(&self) -> usize; - fn get_opts(&self) -> &Opts; - fn get_pending_result(&self) -> Option<&PendingResult>; - fn get_server_version(&self) -> (u16, u16, u16); - fn get_status(&self) -> StatusFlags; - fn set_last_ok_packet(&mut self, ok_packet: Option>); - fn set_tx_status(&mut self, tx_statux: TxStatus); - fn set_pending_result(&mut self, meta: Option); - fn set_status(&mut self, status: StatusFlags); - fn reset_seq_id(&mut self); - fn sync_seq_id(&mut self); - fn touch(&mut self) -> (); - fn on_disconnect(&mut self); - - fn cache_stmt<'a>(&'a mut self, stmt: &Arc) -> Option> { - let query = stmt.raw_query.clone(); - if self.get_opts().get_stmt_cache_size() > 0 { - self.stmt_cache_mut().put(query, stmt.clone()) - } else { - None - } - } - - fn get_cached_stmt(&mut self, raw_query: &str) -> Option> { - self.stmt_cache_mut() - .by_query(raw_query) - .map(|entry| entry.stmt.clone()) - } - - fn read_packet<'a>(&'a mut self) -> ReadPacket<'a, Self> { - ReadPacket::new(self) - } - - /// Returns future that reads packets from a server. - fn read_packets<'a>(&'a mut self, n: usize) -> ReadPackets<'a, Self> { - ReadPackets::new(self, n) - } - - /// Returns future that reads result set from a server. - fn read_result_set<'a, P>(&'a mut self) -> BoxFuture<'a, QueryResult<'a, Self, P>> - where - Self: Sized, - P: Protocol, - { - BoxFuture(Box::pin(async move { - let packet = self.read_packet().await?; - match packet.get(0) { - Some(0x00) => Ok(QueryResult::new(self, None)), - Some(0xFB) => handle_local_infile(self, &*packet).await, - _ => handle_result_set(self, &*packet).await, - } - })) - } - - fn write_packet(&mut self, data: T) -> WritePacket<'_, Self> - where - T: Into>, - { - WritePacket::new(self, data.into()) - } - - /// Returns future that sends full command body to a server. - fn write_command_raw<'a>(&'a mut self, body: Vec) -> WritePacket<'a, Self> { - assert!(body.len() > 0); - self.reset_seq_id(); - self.write_packet(body) - } - - /// Returns future that writes command to a server. - fn write_command_data(&mut self, cmd: Command, cmd_data: T) -> WritePacket<'_, Self> - where - T: AsRef<[u8]>, - { - let cmd_data = cmd_data.as_ref(); - let mut body = Vec::with_capacity(1 + cmd_data.len()); - body.push(cmd as u8); - body.extend_from_slice(cmd_data); - self.write_command_raw(body) - } -} - -/// Will handle local infile packet. -async fn handle_local_infile<'a, T: ?Sized, P>( - this: &'a mut T, - packet: &[u8], -) -> Result> -where - P: Protocol, - T: ConnectionLike + Sized, -{ - let local_infile = parse_local_infile_packet(&*packet)?; - let (local_infile, handler) = match this.get_local_infile_handler() { - Some(handler) => ((local_infile.into_owned(), handler)), - None => return Err(DriverError::NoLocalInfileHandler.into()), - }; - let mut reader = handler.handle(local_infile.file_name_ref()).await?; - - let mut buf = [0; 4096]; - loop { - let read = reader.read(&mut buf[..]).await?; - this.write_packet(&buf[..read]).await?; - - if read == 0 { - break; - } - } - - this.read_packet().await?; - Ok(QueryResult::new(this, None)) -} - -/// Will handle result set packet. -async fn handle_result_set<'a, T: Sized, P>( - this: &'a mut T, - mut packet: &[u8], -) -> Result> -where - P: Protocol, - T: ConnectionLike, -{ - let column_count = packet.read_lenenc_int()?; - let packets = this.read_packets(column_count as usize).await?; - let columns = packets - .into_iter() - .map(|packet| column_from_payload(packet).map_err(Error::from)) - .collect::>>()?; - - if !this - .get_capabilities() - .contains(CapabilityFlags::CLIENT_DEPRECATE_EOF) - { - this.read_packet().await?; - } - - if column_count > 0 { - let columns = Arc::new(columns); - this.set_pending_result(Some(P::pending_result(columns.clone()))); - Ok(QueryResult::new(this, Some(columns))) - } else { - this.set_pending_result(Some(PendingResult::Empty)); - Ok(QueryResult::new(this, None)) - } } diff --git a/src/connection_like/read_packet.rs b/src/connection_like/read_packet.rs deleted file mode 100644 index cdafbcbe..00000000 --- a/src/connection_like/read_packet.rs +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright (c) 2017 Anatoly Ikorsky -// -// Licensed under the Apache License, Version 2.0 -// or the MIT -// license , at your -// option. All files in the project carrying such notice may not be copied, -// modified, or distributed except according to those terms. - -use futures_core::{ready, stream::Stream}; -use mysql_common::packets::{parse_err_packet, parse_ok_packet, OkPacketKind}; - -use std::{ - future::Future, - mem, - pin::Pin, - task::{Context, Poll}, -}; - -use crate::{connection_like::ConnectionLike, consts::StatusFlags, error::*}; - -/// Reads some number of packets. -#[derive(Debug)] -#[must_use = "futures do nothing unless you `.await` or poll them"] -pub struct ReadPackets<'a, T: ?Sized> { - conn_like: &'a mut T, - n: usize, - packets: Vec>, -} - -impl<'a, T: ?Sized> ReadPackets<'a, T> { - pub(crate) fn new(conn_like: &'a mut T, n: usize) -> Self { - Self { - conn_like, - n, - packets: Vec::with_capacity(n), - } - } -} - -impl<'a, T: ConnectionLike> Future for ReadPackets<'a, T> { - type Output = Result>>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - loop { - if self.n > 0 { - let packet_opt = - ready!(Pin::new(self.conn_like.stream_mut()).poll_next(cx)).transpose()?; - match packet_opt { - Some(packet) => { - let kind = if self.conn_like.get_pending_result().is_some() { - OkPacketKind::ResultSetTerminator - } else { - OkPacketKind::Other - }; - - if let Ok(ok_packet) = - parse_ok_packet(&*packet, self.conn_like.get_capabilities(), kind) - { - self.conn_like.set_status(ok_packet.status_flags()); - self.conn_like - .set_last_ok_packet(Some(ok_packet.into_owned())); - } else if let Ok(err_packet) = - parse_err_packet(&*packet, self.conn_like.get_capabilities()) - { - self.conn_like.set_status(StatusFlags::empty()); - self.conn_like.set_last_ok_packet(None); - return Err(err_packet.into()).into(); - } - - self.conn_like.touch(); - self.packets.push(packet); - self.n -= 1; - continue; - } - None => { - return Poll::Ready(Err(DriverError::ConnectionClosed.into())); - } - } - } else { - return Poll::Ready(Ok(mem::replace(&mut self.packets, Vec::new()))); - } - } - } -} - -pub struct ReadPacket<'a, T: ?Sized> { - inner: ReadPackets<'a, T>, -} - -impl<'a, T: ?Sized> ReadPacket<'a, T> { - pub(crate) fn new(conn_like: &'a mut T) -> Self { - Self { - inner: ReadPackets::new(conn_like, 1), - } - } -} - -impl<'a, T: ConnectionLike> Future for ReadPacket<'a, T> { - type Output = Result>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut packets = ready!(Pin::new(&mut self.inner).poll(cx))?; - Poll::Ready(Ok(packets.pop().unwrap())) - } -} diff --git a/src/error.rs b/src/error.rs index 3cfb4fab..43363dc5 100644 --- a/src/error.rs +++ b/src/error.rs @@ -19,8 +19,6 @@ use std::{borrow::Cow, io, result}; /// Result type alias for this library. pub type Result = result::Result; -pub(crate) type StdResult = result::Result; - /// This type enumerates library errors. #[derive(Debug, Error)] pub enum Error { @@ -28,7 +26,7 @@ pub enum Error { Driver(#[source] DriverError), #[error("Input/output error: {}", _0)] - Io(#[source] io::Error), + Io(#[source] IoError), #[error("Other error: {}", _0)] Other(Cow<'static, str>), @@ -36,13 +34,20 @@ pub enum Error { #[error("Server error: `{}'", _0)] Server(#[source] ServerError), - #[error("TLS error: `{}'", _0)] - Tls(#[source] native_tls::Error), - #[error("URL error: `{}'", _0)] Url(#[source] UrlError), } +/// This type enumerates IO errors. +#[derive(Debug, Error)] +pub enum IoError { + #[error("Input/output error: {}", _0)] + Io(#[source] io::Error), + + #[error("TLS error: `{}'", _0)] + Tls(#[source] native_tls::Error), +} + /// This type represents MySql server error. #[derive(Debug, Error, Clone, Eq, PartialEq)] #[error("ERROR {} ({}): {}", state, code, message)] @@ -142,9 +147,21 @@ impl From for Error { } } +impl From for Error { + fn from(io: IoError) -> Self { + Error::Io(io) + } +} + +impl From for IoError { + fn from(err: io::Error) -> Self { + IoError::Io(err) + } +} + impl From for Error { fn from(err: io::Error) -> Self { - Error::Io(err) + Error::Io(err.into()) } } @@ -160,9 +177,9 @@ impl From for Error { } } -impl From for Error { +impl From for IoError { fn from(err: native_tls::Error) -> Self { - Error::Tls(err) + IoError::Tls(err) } } @@ -237,15 +254,25 @@ impl From for Error { } } -impl From for Error { +impl From for IoError { fn from(err: PacketCodecError) -> Self { match err { PacketCodecError::Io(err) => err.into(), - PacketCodecError::PacketTooLarge => DriverError::PacketTooLarge.into(), - PacketCodecError::PacketsOutOfSync => DriverError::PacketOutOfOrder.into(), + PacketCodecError::PacketTooLarge => { + io::Error::new(io::ErrorKind::Other, "packet too large").into() + } + PacketCodecError::PacketsOutOfSync => { + io::Error::new(io::ErrorKind::Other, "packet out of order").into() + } PacketCodecError::BadCompressedPacketHeader => { - DriverError::BadCompressedPacketHeader.into() + io::Error::new(io::ErrorKind::Other, "bad compressed packet header").into() } } } } + +impl From for Error { + fn from(err: PacketCodecError) -> Self { + Error::Io(err.into()) + } +} diff --git a/src/io/mod.rs b/src/io/mod.rs index 45af959a..10e8b785 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -6,6 +6,8 @@ // option. All files in the project carrying such notice may not be copied, // modified, or distributed except according to those terms. +pub use self::{read_packet::ReadPacket, write_packet::WritePacket}; + use bytes::{BufMut, BytesMut}; use futures_core::{ready, stream}; use futures_util::stream::{FuturesUnordered, StreamExt}; @@ -18,7 +20,12 @@ use tokio_util::codec::{Decoder, Encoder, Framed, FramedParts}; use std::{ fmt, fs::File, - io::Read, + future::Future, + io::{ + self, + ErrorKind::{Other, UnexpectedEof}, + Read, + }, mem::MaybeUninit, net::ToSocketAddrs, ops::{Deref, DerefMut}, @@ -28,9 +35,11 @@ use std::{ time::Duration, }; -use crate::{error::*, io::socket::Socket, opts::SslOpts}; +use crate::{error::IoError, io::socket::Socket, opts::SslOpts}; +mod read_packet; mod socket; +mod write_packet; #[derive(Debug, Default)] pub struct PacketCodec(PacketCodecInner); @@ -51,18 +60,18 @@ impl DerefMut for PacketCodec { impl Decoder for PacketCodec { type Item = Vec; - type Error = Error; + type Error = IoError; - fn decode(&mut self, src: &mut BytesMut) -> Result> { + fn decode(&mut self, src: &mut BytesMut) -> std::result::Result, IoError> { Ok(self.0.decode(src)?) } } impl Encoder for PacketCodec { type Item = Vec; - type Error = Error; + type Error = IoError; - fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<()> { + fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> std::result::Result<(), IoError> { Ok(self.0.encode(item, dst)?) } } @@ -75,7 +84,45 @@ pub(crate) enum Endpoint { Socket(#[pin] Socket), } +/// This future will check that TcpStream is live. +/// +/// This check is similar to a one, implemented by GitHub team for the go-sql-driver/mysql. +#[derive(Debug)] +struct CheckTcpStream<'a>(&'a mut TcpStream); + +impl Future for CheckTcpStream<'_> { + type Output = io::Result<()>; + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let buf = &mut [0_u8]; + match self.0.poll_peek(cx, buf) { + Poll::Ready(Ok(0)) => Poll::Ready(Err(io::Error::new(UnexpectedEof, "broken pipe"))), + Poll::Ready(Ok(_)) => Poll::Ready(Err(io::Error::new(Other, "unexpected read"))), + Poll::Ready(Err(err)) => Poll::Ready(Err(err)), + Poll::Pending => Poll::Ready(Ok(())), + } + } +} + impl Endpoint { + /// Checks, that connection is alive. + async fn check(&mut self) -> std::result::Result<(), IoError> { + match self { + Endpoint::Plain(Some(stream)) => { + CheckTcpStream(stream).await?; + Ok(()) + } + Endpoint::Secure(tls_stream) => { + CheckTcpStream(tls_stream.get_mut()).await?; + Ok(()) + } + Endpoint::Socket(socket) => { + socket.write(&[]).await?; + Ok(()) + } + Endpoint::Plain(None) => unreachable!(), + } + } + pub fn is_secure(&self) -> bool { if let Endpoint::Secure(_) = self { true @@ -84,7 +131,7 @@ impl Endpoint { } } - pub fn set_keepalive_ms(&self, ms: Option) -> Result<()> { + pub fn set_keepalive_ms(&self, ms: Option) -> io::Result<()> { let ms = ms.map(|val| Duration::from_millis(u64::from(val))); match *self { Endpoint::Plain(Some(ref stream)) => stream.set_keepalive(ms)?, @@ -95,7 +142,7 @@ impl Endpoint { Ok(()) } - pub fn set_tcp_nodelay(&self, val: bool) -> Result<()> { + pub fn set_tcp_nodelay(&self, val: bool) -> io::Result<()> { match *self { Endpoint::Plain(Some(ref stream)) => stream.set_nodelay(val)?, Endpoint::Plain(None) => unreachable!(), @@ -105,7 +152,11 @@ impl Endpoint { Ok(()) } - pub async fn make_secure(&mut self, domain: String, ssl_opts: SslOpts) -> Result<()> { + pub async fn make_secure( + &mut self, + domain: String, + ssl_opts: SslOpts, + ) -> std::result::Result<(), IoError> { if let Endpoint::Socket(_) = self { // inapplicable return Ok(()); @@ -168,7 +219,7 @@ impl AsyncRead for Endpoint { self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8], - ) -> Poll> { + ) -> Poll> { #[project] match self.project() { Endpoint::Plain(ref mut stream) => { @@ -193,7 +244,7 @@ impl AsyncRead for Endpoint { self: Pin<&mut Self>, cx: &mut Context, buf: &mut B, - ) -> Poll> + ) -> Poll> where B: BufMut, { @@ -214,7 +265,7 @@ impl AsyncWrite for Endpoint { self: Pin<&mut Self>, cx: &mut Context, buf: &[u8], - ) -> Poll> { + ) -> Poll> { #[project] match self.project() { Endpoint::Plain(ref mut stream) => { @@ -226,7 +277,10 @@ impl AsyncWrite for Endpoint { } #[project] - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { #[project] match self.project() { Endpoint::Plain(ref mut stream) => Pin::new(stream.as_mut().unwrap()).poll_flush(cx), @@ -239,7 +293,7 @@ impl AsyncWrite for Endpoint { fn poll_shutdown( self: Pin<&mut Self>, cx: &mut Context, - ) -> Poll> { + ) -> Poll> { #[project] match self.project() { Endpoint::Plain(ref mut stream) => Pin::new(stream.as_mut().unwrap()).poll_shutdown(cx), @@ -275,7 +329,7 @@ impl Stream { } } - pub(crate) async fn connect_tcp(addr: S) -> Result + pub(crate) async fn connect_tcp(addr: S) -> io::Result where S: ToSocketAddrs, { @@ -317,19 +371,23 @@ impl Stream { } } - pub(crate) async fn connect_socket>(path: P) -> Result { + pub(crate) async fn connect_socket>(path: P) -> io::Result { Ok(Stream::new(Socket::new(path).await?)) } - pub(crate) fn set_keepalive_ms(&self, ms: Option) -> Result<()> { + pub(crate) fn set_keepalive_ms(&self, ms: Option) -> io::Result<()> { self.codec.as_ref().unwrap().get_ref().set_keepalive_ms(ms) } - pub(crate) fn set_tcp_nodelay(&self, val: bool) -> Result<()> { + pub(crate) fn set_tcp_nodelay(&self, val: bool) -> io::Result<()> { self.codec.as_ref().unwrap().get_ref().set_tcp_nodelay(val) } - pub(crate) async fn make_secure(&mut self, domain: String, ssl_opts: SslOpts) -> Result<()> { + pub(crate) async fn make_secure( + &mut self, + domain: String, + ssl_opts: SslOpts, + ) -> crate::error::Result<()> { let codec = self.codec.take().unwrap(); let FramedParts { mut io, codec, .. } = codec.into_parts(); io.make_secure(domain, ssl_opts).await?; @@ -366,7 +424,15 @@ impl Stream { } } - pub(crate) async fn close(mut self) -> Result<()> { + /// Checks, that connection is alive. + pub(crate) async fn check(&mut self) -> std::result::Result<(), IoError> { + if let Some(codec) = self.codec.as_mut() { + codec.get_mut().check().await?; + } + Ok(()) + } + + pub(crate) async fn close(mut self) -> std::result::Result<(), IoError> { self.closed = true; if let Some(mut codec) = self.codec { use futures_sink::Sink; @@ -377,7 +443,7 @@ impl Stream { } impl stream::Stream for Stream { - type Item = Result>; + type Item = std::result::Result, IoError>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { if !self.closed { diff --git a/src/io/read_packet.rs b/src/io/read_packet.rs new file mode 100644 index 00000000..8996f1d8 --- /dev/null +++ b/src/io/read_packet.rs @@ -0,0 +1,54 @@ +// Copyright (c) 2017 Anatoly Ikorsky +// +// Licensed under the Apache License, Version 2.0 +// or the MIT +// license , at your +// option. All files in the project carrying such notice may not be copied, +// modified, or distributed except according to those terms. + +use futures_core::{ready, stream::Stream}; + +use std::{ + future::Future, + io::{Error, ErrorKind}, + pin::Pin, + task::{Context, Poll}, +}; + +use crate::{connection_like::ConnectionLike, error::IoError}; + +/// Reads a packet. +#[derive(Debug)] +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct ReadPacket<'a, T: ?Sized> { + conn_like: &'a mut T, +} + +impl<'a, T: ?Sized> ReadPacket<'a, T> { + pub(crate) fn new(conn_like: &'a mut T) -> Self { + Self { conn_like } + } +} + +impl<'a, T: ConnectionLike> Future for ReadPacket<'a, T> { + type Output = std::result::Result, IoError>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let packet_opt = + ready!(Pin::new(self.conn_like.conn_mut().stream_mut()).poll_next(cx)).transpose()?; + + match packet_opt { + Some(packet) => { + self.conn_like.conn_mut().touch(); + return Poll::Ready(Ok(packet)); + } + None => { + return Poll::Ready(Err(Error::new( + ErrorKind::UnexpectedEof, + "connection closed", + ) + .into())); + } + } + } +} diff --git a/src/io/socket.rs b/src/io/socket.rs index bd5efb4e..45881e64 100644 --- a/src/io/socket.rs +++ b/src/io/socket.rs @@ -8,13 +8,15 @@ use bytes::BufMut; use pin_project::pin_project; +use tokio::{io::Error, prelude::*}; + use std::{ + io, + mem::MaybeUninit, + path::Path, pin::Pin, task::{Context, Poll}, }; -use tokio::{io::Error, prelude::*}; - -use std::{io, mem::MaybeUninit, path::Path}; /// Unix domain socket connection on unix, or named pipe connection on windows. #[pin_project] diff --git a/src/connection_like/write_packet.rs b/src/io/write_packet.rs similarity index 88% rename from src/connection_like/write_packet.rs rename to src/io/write_packet.rs index ae04cf46..5a3546cf 100644 --- a/src/connection_like/write_packet.rs +++ b/src/io/write_packet.rs @@ -15,7 +15,7 @@ use std::{ task::{Context, Poll}, }; -use crate::{connection_like::ConnectionLike, error::*}; +use crate::{connection_like::ConnectionLike, error::IoError}; pub struct WritePacket<'a, T: ?Sized> { conn_like: &'a mut T, @@ -35,12 +35,13 @@ impl<'a, T> Future for WritePacket<'a, T> where T: ConnectionLike, { - type Output = Result<()>; + type Output = std::result::Result<(), IoError>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { if self.data.is_some() { let codec = Pin::new( self.conn_like + .conn_mut() .stream_mut() .codec .as_mut() @@ -52,6 +53,7 @@ where if let Some(data) = self.data.take() { let codec = Pin::new( self.conn_like + .conn_mut() .stream_mut() .codec .as_mut() @@ -63,13 +65,14 @@ where let codec = Pin::new( self.conn_like + .conn_mut() .stream_mut() .codec .as_mut() .expect("must be here"), ); - ready!(codec.poll_flush(cx)).map_err(Error::from)?; + ready!(codec.poll_flush(cx))?; Poll::Ready(Ok(())) } diff --git a/src/lib.rs b/src/lib.rs index 6e719f03..7a18fc4f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -128,6 +128,17 @@ impl Future for BoxFuture<'_, T> { } } +impl<'a, T> std::fmt::Debug for BoxFuture<'a, T> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("BoxFuture") + .field(&format!( + "dyn Future", + std::any::type_name::() + )) + .finish() + } +} + #[doc(inline)] pub use self::conn::Conn; diff --git a/src/queryable/mod.rs b/src/queryable/mod.rs index 15900fbf..c17c702b 100644 --- a/src/queryable/mod.rs +++ b/src/queryable/mod.rs @@ -20,12 +20,14 @@ use self::{ transaction::{Transaction, TxStatus}, }; use crate::{ - conn::PendingResult, connection_like::ConnectionLike, consts::Command, error::*, prelude::FromRow, - queryable::stmt::{close_statement, execute_statement, get_statement, StatementLike}, + queryable::{ + query_result::{read_result_set, ResultSetMeta}, + stmt::StatementLike, + }, BoxFuture, Column, Conn, Params, Row, }; @@ -34,7 +36,8 @@ pub mod stmt; pub mod transaction; pub trait Protocol: Send + 'static { - fn pending_result(columns: Arc>) -> PendingResult; + /// Returns `ResultSetMeta`, that corresponds to the current protocol. + fn result_set_meta(columns: Arc>) -> ResultSetMeta; fn read_result_set_row(packet: &[u8], columns: Arc>) -> Result; fn is_last_result_set_packet(conn_like: &T, packet: &[u8]) -> bool where @@ -42,7 +45,7 @@ pub trait Protocol: Send + 'static { { parse_ok_packet( packet, - conn_like.get_capabilities(), + conn_like.conn_ref().capabilities(), OkPacketKind::ResultSetTerminator, ) .is_ok() @@ -56,8 +59,8 @@ pub struct TextProtocol; pub struct BinaryProtocol; impl Protocol for TextProtocol { - fn pending_result(columns: Arc>) -> PendingResult { - PendingResult::Text(columns) + fn result_set_meta(columns: Arc>) -> ResultSetMeta { + ResultSetMeta::Text(columns) } fn read_result_set_row(packet: &[u8], columns: Arc>) -> Result { @@ -68,8 +71,8 @@ impl Protocol for TextProtocol { } impl Protocol for BinaryProtocol { - fn pending_result(columns: Arc>) -> PendingResult { - PendingResult::Binary(columns) + fn result_set_meta(columns: Arc>) -> ResultSetMeta { + ResultSetMeta::Binary(columns) } fn read_result_set_row(packet: &[u8], columns: Arc>) -> Result { @@ -83,8 +86,8 @@ impl Protocol for BinaryProtocol { /// where `Transaction` is dropped without an explicit call to `commit` or `rollback`. async fn cleanup(queryable: &mut T) -> Result<()> { queryable.conn_mut().drop_result().await?; - if queryable.get_tx_status() == TxStatus::RequiresRollback { - queryable.set_tx_status(TxStatus::None); + if queryable.conn_ref().get_tx_status() == TxStatus::RequiresRollback { + queryable.conn_mut().set_tx_status(TxStatus::None); queryable.exec_drop("ROLLBACK", ()).await?; } Ok(()) @@ -92,12 +95,13 @@ async fn cleanup(queryable: &mut T) -> Result<()> { pub trait Queryable: crate::prelude::ConnectionLike { /// Executes `COM_PING`. - fn ping<'a>(&'a mut self) -> BoxFuture<'a, ()> { + fn ping(&mut self) -> BoxFuture<'_, ()> { BoxFuture(Box::pin(async move { cleanup(self).await?; - self.write_command_raw(vec![Command::COM_PING as u8]) + self.conn_mut() + .write_command_raw(vec![Command::COM_PING as u8]) .await?; - self.read_packet().await?; + self.conn_mut().read_packet().await?; Ok(()) })) } @@ -112,9 +116,10 @@ pub trait Queryable: crate::prelude::ConnectionLike { { BoxFuture(Box::pin(async move { cleanup(self).await?; - self.write_command_data(Command::COM_QUERY, query.as_ref().as_bytes()) + self.conn_mut() + .write_command_data(Command::COM_QUERY, query.as_ref().as_bytes()) .await?; - self.read_result_set().await + read_result_set(self).await })) } @@ -125,7 +130,7 @@ pub trait Queryable: crate::prelude::ConnectionLike { { BoxFuture(Box::pin(async move { cleanup(self).await?; - get_statement(self, query.as_ref()).await + self.conn_mut().get_statement(query.as_ref()).await })) } @@ -133,8 +138,8 @@ pub trait Queryable: crate::prelude::ConnectionLike { fn close(&mut self, stmt: Statement) -> BoxFuture<'_, ()> { BoxFuture(Box::pin(async move { cleanup(self).await?; - self.stmt_cache_mut().remove(stmt.id()); - close_statement(self, stmt.id()).await + self.conn_mut().stmt_cache_mut().remove(stmt.id()); + self.conn_mut().close_statement(stmt.id()).await })) } @@ -143,7 +148,7 @@ pub trait Queryable: crate::prelude::ConnectionLike { &'a mut self, stmt: &'b Q, params: P, - ) -> BoxFuture<'b, QueryResult<'a, Self, BinaryProtocol>> + ) -> BoxFuture<'b, QueryResult<'a, crate::Conn, BinaryProtocol>> where Q: StatementLike + ?Sized + 'a, P: Into, @@ -151,8 +156,8 @@ pub trait Queryable: crate::prelude::ConnectionLike { let params = params.into(); BoxFuture(Box::pin(async move { cleanup(self).await?; - let statement = get_statement(self, stmt).await?; - execute_statement(self, &statement, params).await + let statement = self.conn_mut().get_statement(stmt).await?; + self.conn_mut().execute_statement(&statement, params).await })) } @@ -242,9 +247,10 @@ pub trait Queryable: crate::prelude::ConnectionLike { { BoxFuture(Box::pin(async move { cleanup(self).await?; - let statement = get_statement(self, stmt).await?; + let statement = self.conn_mut().get_statement(stmt).await?; for params in params_iter { - execute_statement(self, &statement, params) + self.conn_mut() + .execute_statement(&statement, params) .await? .drop_result() .await?; diff --git a/src/queryable/query_result/mod.rs b/src/queryable/query_result/mod.rs index bc1b3ab5..9eb6a36a 100644 --- a/src/queryable/query_result/mod.rs +++ b/src/queryable/query_result/mod.rs @@ -7,46 +7,52 @@ // modified, or distributed except according to those terms. use mysql_common::row::convert::FromRowError; +use mysql_common::{ + io::ReadMysqlExt, + packets::{column_from_payload, parse_local_infile_packet}, +}; +use tokio::prelude::*; use std::{borrow::Cow, marker::PhantomData, result::Result as StdResult, sync::Arc}; -use self::QueryResultInner::*; use crate::{ - consts::StatusFlags, + consts::{CapabilityFlags, StatusFlags}, error::*, prelude::{ConnectionLike, FromRow, Protocol}, Column, Row, }; -enum QueryResultInner { +/// Result set metadata. +#[derive(Debug, Clone, Eq, PartialEq)] +pub enum ResultSetMeta { + /// Text result set, that may contain rows. + Text(Arc>), + /// Binary result set, that may contain rows. + Binary(Arc>), + /// Result with no rows. Empty, - WithRows(Arc>), } -impl QueryResultInner { - fn new(columns: Option>>) -> Self { - match columns { - Some(columns) => WithRows(columns), - None => Empty, - } - } - +impl ResultSetMeta { fn columns(&self) -> Option<&Arc>> { match self { - WithRows(columns) => Some(columns), - Empty => None, + ResultSetMeta::Text(columns) | ResultSetMeta::Binary(columns) => Some(columns), + ResultSetMeta::Empty => None, } } - fn make_empty(&mut self) { - *self = Empty + fn into_columns(self) -> Option>> { + match self { + ResultSetMeta::Text(columns) | ResultSetMeta::Binary(columns) => Some(columns), + ResultSetMeta::Empty => None, + } } } /// Result of a query or statement execution. pub struct QueryResult<'a, T: ?Sized, P> { conn_like: &'a mut T, - inner: QueryResultInner, + meta: ResultSetMeta, __phantom: PhantomData

, } @@ -55,27 +61,21 @@ where P: Protocol, T: ConnectionLike, { - pub(crate) fn new( - conn_like: &'a mut T, - columns: Option>>, - ) -> QueryResult<'a, T, P> { + pub(crate) fn new(conn_like: &'a mut T, meta: ResultSetMeta) -> QueryResult<'a, T, P> { QueryResult { conn_like, - inner: QueryResultInner::new(columns), + meta, __phantom: PhantomData, } } pub(crate) fn disassemble(self) -> (&'a mut T, Option>>) { - match self.inner { - WithRows(columns) => (self.conn_like, Some(columns)), - Empty => (self.conn_like, None), - } + (self.conn_like, self.meta.into_columns()) } fn make_empty(&mut self) { - self.conn_like.set_pending_result(None); - self.inner.make_empty(); + self.conn_like.conn_mut().set_pending_result(None); + self.meta = ResultSetMeta::Empty; } async fn get_row_raw(&mut self) -> Result>> { @@ -83,13 +83,13 @@ where return Ok(None); } - let packet: Vec = self.conn_like.read_packet().await?; + let packet: Vec = self.conn_like.conn_mut().read_packet().await?; if P::is_last_result_set_packet(&*self.conn_like, &packet) { if self.more_results_exists() { - self.conn_like.sync_seq_id(); - let next_set = self.conn_like.read_result_set::

().await?; - self.inner = next_set.inner; + self.conn_like.conn_mut().sync_seq_id(); + let next_set = read_result_set::<_, P>(self.conn_like.conn_mut()).await?; + self.meta = next_set.meta; Ok(None) } else { self.make_empty(); @@ -102,11 +102,11 @@ where /// Returns next row, if any. /// - /// Requires that `self.inner` matches `WithRows(..)`. + /// Requires that `self.meta` is not `Empty`. pub(crate) async fn get_row(&mut self) -> Result> { let packet = self.get_row_raw().await?; if let Some(packet) = packet { - let columns = self.inner.columns().expect("must be here"); + let columns = self.meta.columns().expect("must be here"); let row = P::read_result_set_row(&packet, columns.clone())?; Ok(Some(row)) } else { @@ -116,22 +116,22 @@ where /// Last insert id, if any. pub fn last_insert_id(&self) -> Option { - self.conn_like.get_last_insert_id() + self.conn_like.conn_ref().last_insert_id() } /// Number of affected rows, as reported by the server, or `0`. pub fn affected_rows(&self) -> u64 { - self.conn_like.get_affected_rows() + self.conn_like.conn_ref().affected_rows() } /// Text information, as reported by the server, or an empty string. pub fn info(&self) -> Cow<'_, str> { - self.conn_like.get_info() + self.conn_like.conn_ref().info() } /// Number of warnings, as reported by the server, or `0`. pub fn warnings(&self) -> u16 { - self.conn_like.get_warnings() + self.conn_like.conn_ref().get_warnings() } /// `true` if there is no more rows nor result sets in this query. @@ -145,7 +145,8 @@ where /// of the connection. fn more_results_exists(&self) -> bool { self.conn_like - .get_status() + .conn_ref() + .status() .contains(StatusFlags::SERVER_MORE_RESULTS_EXISTS) } @@ -153,7 +154,7 @@ where /// /// If `false` then there is no rows possible (for example UPDATE query). fn has_rows(&self) -> bool { - matches!(self.inner, WithRows(..)) + !matches!(self.meta, ResultSetMeta::Empty) } /// Returns a future that collects result set of this query result. @@ -170,7 +171,7 @@ where /// It'll panic if any row isn't convertible to `R` (i.e. programmer error or unknown schema). /// * In case of programmer error see [`FromRow`] docs; /// * In case of unknown schema use [`QueryResult::try_collect`]. - pub async fn collect<'b, R>(&mut self) -> Result> + pub async fn collect(&mut self) -> Result> where R: FromRow + Send + 'static, { @@ -204,7 +205,7 @@ where /// It'll panic if any row isn't convertible to `R` (i.e. programmer error or unknown schema). /// * In case of programmer error see `FromRow` docs; /// * In case of unknown schema use [`QueryResult::try_collect`]. - pub async fn collect_and_drop<'b, R>(mut self) -> Result> + pub async fn collect_and_drop(mut self) -> Result> where R: FromRow + Send + 'static, { @@ -331,7 +332,7 @@ where if !self.has_rows() { if self.more_results_exists() { let (inner, _) = self.disassemble(); - self = inner.read_result_set().await?; + self = read_result_set(inner).await?; } else { break; } @@ -344,8 +345,10 @@ where } /// Returns a reference to a columns list of this query result. + /// + /// Empty list means, that this result set was never meant to contain rows. pub fn columns_ref(&self) -> &[Column] { - self.inner + self.meta .columns() .map(|columns| &***columns) .unwrap_or_default() @@ -353,6 +356,86 @@ where /// Returns a copy of a columns list of this query result. pub fn columns(&self) -> Option>> { - self.inner.columns().cloned() + self.meta.columns().cloned() + } +} + +/// Helper, that reads result set from a server. +pub(crate) async fn read_result_set<'a, T, P>(conn_like: &'a mut T) -> Result> +where + T: ConnectionLike, + P: Protocol, +{ + let packet = conn_like.conn_mut().read_packet().await?; + match packet.get(0) { + Some(0x00) => Ok(QueryResult::new(conn_like, ResultSetMeta::Empty)), + Some(0xFB) => handle_local_infile(conn_like, &*packet).await, + _ => handle_result_set(conn_like, &*packet).await, + } +} + +/// Helper that handles local infile packet. +pub(crate) async fn handle_local_infile<'a, T: ?Sized, P>( + this: &'a mut T, + packet: &[u8], +) -> Result> +where + P: Protocol, + T: ConnectionLike + Sized, +{ + let local_infile = parse_local_infile_packet(&*packet)?; + let (local_infile, handler) = match this.conn_mut().opts().get_local_infile_handler() { + Some(handler) => ((local_infile.into_owned(), handler)), + None => return Err(DriverError::NoLocalInfileHandler.into()), + }; + let mut reader = handler.handle(local_infile.file_name_ref()).await?; + + let mut buf = [0; 4096]; + loop { + let read = reader.read(&mut buf[..]).await?; + this.conn_mut().write_packet(&buf[..read]).await?; + + if read == 0 { + break; + } + } + + this.conn_mut().read_packet().await?; + Ok(QueryResult::new(this, ResultSetMeta::Empty)) +} + +/// Helper that handles result set packet. +pub(crate) async fn handle_result_set<'a, T: Sized, P>( + this: &'a mut T, + mut packet: &[u8], +) -> Result> +where + P: Protocol, + T: ConnectionLike, +{ + let column_count = packet.read_lenenc_int()?; + let packets = this.conn_mut().read_packets(column_count as usize).await?; + let columns = packets + .into_iter() + .map(|packet| column_from_payload(packet).map_err(Error::from)) + .collect::>>()?; + + if !this + .conn_ref() + .capabilities() + .contains(CapabilityFlags::CLIENT_DEPRECATE_EOF) + { + this.conn_mut().read_packet().await?; + } + + if column_count > 0 { + let columns = Arc::new(columns); + let meta = P::result_set_meta(columns.clone()); + this.conn_mut().set_pending_result(Some(meta.clone())); + Ok(QueryResult::new(this, meta)) + } else { + this.conn_mut() + .set_pending_result(Some(ResultSetMeta::Empty)); + Ok(QueryResult::new(this, ResultSetMeta::Empty)) } } diff --git a/src/queryable/stmt.rs b/src/queryable/stmt.rs index 02d9b7a1..0c0146bd 100644 --- a/src/queryable/stmt.rs +++ b/src/queryable/stmt.rs @@ -21,169 +21,13 @@ use crate::{ connection_like::ConnectionLike, consts::{CapabilityFlags, Command}, error::*, - queryable::{query_result::QueryResult, BinaryProtocol}, + queryable::{ + query_result::{read_result_set, QueryResult}, + BinaryProtocol, + }, Column, Params, Value, }; -pub(crate) async fn get_statement(conn_like: &mut T, stmt_like: &U) -> Result -where - T: ConnectionLike, - U: StatementLike + ?Sized, -{ - let (named_params, raw_query) = stmt_like.info()?; - let stmt_inner = if let Some(stmt_inner) = conn_like.get_cached_stmt(raw_query.as_ref()) { - stmt_inner - } else { - prepare_statement(conn_like, raw_query).await? - }; - Ok(Statement::new(stmt_inner, named_params)) -} - -pub(crate) async fn prepare_statement<'a, T>( - conn_like: &'a mut T, - raw_query: Cow<'_, str>, -) -> Result> -where - T: ConnectionLike, -{ - /// Requires `num > 0`. - async fn read_column_defs(conn_like: &mut T, num: U) -> Result> - where - T: ConnectionLike, - U: Into, - { - let num = num.into(); - debug_assert!(num > 0); - let packets = conn_like.read_packets(num).await?; - let defs = packets - .into_iter() - .map(column_from_payload) - .collect::, _>>() - .map_err(Error::from)?; - - if !conn_like - .get_capabilities() - .contains(CapabilityFlags::CLIENT_DEPRECATE_EOF) - { - conn_like.read_packet().await?; - } - - Ok(defs) - } - - let raw_query: Arc = raw_query.into_owned().into_boxed_str().into(); - - conn_like - .write_command_data(Command::COM_STMT_PREPARE, raw_query.as_bytes()) - .await?; - - let packet = conn_like.read_packet().await?; - let mut inner_stmt = StmtInner::from_payload(&*packet, conn_like.connection_id(), raw_query)?; - - if inner_stmt.num_params() > 0 { - let params = read_column_defs(conn_like, inner_stmt.num_params()).await?; - inner_stmt = inner_stmt.with_params(params); - } - - if inner_stmt.num_columns() > 0 { - let columns = read_column_defs(conn_like, inner_stmt.num_columns()).await?; - inner_stmt = inner_stmt.with_columns(columns); - } - - let inner_stmt = Arc::new(inner_stmt); - - if let Some(old_stmt) = conn_like.cache_stmt(&inner_stmt) { - close_statement(conn_like, old_stmt.id()).await?; - } - - Ok(inner_stmt) -} - -pub(crate) async fn execute_statement<'a, T, P>( - conn_like: &'a mut T, - statement: &Statement, - params: P, -) -> Result> -where - T: ConnectionLike, - P: Into, -{ - let mut params = params.into(); - loop { - match params { - Params::Positional(params) => { - if statement.num_params() as usize != params.len() { - Err(DriverError::StmtParamsMismatch { - required: statement.num_params(), - supplied: params.len() as u16, - })? - } - - let params = params.into_iter().collect::>(); - - let (body, as_long_data) = - ComStmtExecuteRequestBuilder::new(statement.id()).build(&*params); - - if as_long_data { - for (i, value) in params.into_iter().enumerate() { - if let Value::Bytes(bytes) = value { - let chunks = bytes.chunks(MAX_PAYLOAD_LEN - 6); - let chunks = chunks.chain(if bytes.is_empty() { - Some(&[][..]) - } else { - None - }); - for chunk in chunks { - let com = ComStmtSendLongData::new(statement.id(), i, chunk); - conn_like.write_command_raw(com.into()).await?; - } - } - } - } - - conn_like.write_command_raw(body).await?; - break conn_like.read_result_set().await; - } - Params::Named(_) => { - if statement.named_params.is_none() { - let error = DriverError::NamedParamsForPositionalQuery.into(); - return Err(error); - } - - params = match params.into_positional(statement.named_params.as_ref().unwrap()) { - Ok(positional_params) => positional_params, - Err(error) => return Err(error.into()), - }; - - continue; - } - Params::Empty => { - if statement.num_params() > 0 { - let error = DriverError::StmtParamsMismatch { - required: statement.num_params(), - supplied: 0, - } - .into(); - return Err(error); - } - - let (body, _) = ComStmtExecuteRequestBuilder::new(statement.id()).build(&[]); - conn_like.write_command_raw(body).await?; - break conn_like.read_result_set().await; - } - } - } -} - -pub(crate) async fn close_statement(conn_like: &mut T, id: u32) -> Result<()> -where - T: ConnectionLike, -{ - conn_like - .write_command_raw(ComStmtClose::new(id).into()) - .await -} - pub trait StatementLike: Send + Sync { /// Returns raw statement query coupled with its nemed parameters. fn info(&self) -> Result<(Option>, Cow)>; @@ -312,3 +156,169 @@ impl Statement { self.inner.num_columns() } } + +impl crate::Conn { + /// Low-level helpers, that reads the given number of column packets from server. + /// + /// Requires `num > 0`. + async fn read_column_defs(&mut self, num: U) -> Result> + where + U: Into, + { + let num = num.into(); + debug_assert!(num > 0); + let packets = self.read_packets(num).await?; + let defs = packets + .into_iter() + .map(column_from_payload) + .collect::, _>>() + .map_err(Error::from)?; + + if !self + .conn_ref() + .capabilities() + .contains(CapabilityFlags::CLIENT_DEPRECATE_EOF) + { + self.read_packet().await?; + } + + Ok(defs) + } + + /// Helper, that retrieves `Statement` from `StatementLike`. + pub(crate) async fn get_statement(&mut self, stmt_like: &U) -> Result + where + U: StatementLike + ?Sized, + { + let (named_params, raw_query) = stmt_like.info()?; + let stmt_inner = if let Some(stmt_inner) = self.get_cached_stmt(raw_query.as_ref()) { + stmt_inner + } else { + self.prepare_statement(raw_query).await? + }; + Ok(Statement::new(stmt_inner, named_params)) + } + + /// Low-level helper, that prepares the given statement. + /// + /// `raw_query` is a query with `?` placeholders (if any). + async fn prepare_statement(&mut self, raw_query: Cow<'_, str>) -> Result> { + let raw_query: Arc = raw_query.into_owned().into_boxed_str().into(); + + self.write_command_data(Command::COM_STMT_PREPARE, raw_query.as_bytes()) + .await?; + + let packet = self.read_packet().await?; + let mut inner_stmt = StmtInner::from_payload(&*packet, self.conn_ref().id(), raw_query)?; + + if inner_stmt.num_params() > 0 { + let params = self.read_column_defs(inner_stmt.num_params()).await?; + inner_stmt = inner_stmt.with_params(params); + } + + if inner_stmt.num_columns() > 0 { + let columns = self.read_column_defs(inner_stmt.num_columns()).await?; + inner_stmt = inner_stmt.with_columns(columns); + } + + let inner_stmt = Arc::new(inner_stmt); + + if let Some(old_stmt) = self.conn_mut().cache_stmt(&inner_stmt) { + self.close_statement(old_stmt.id()).await?; + } + + Ok(inner_stmt) + } + + /// Helper, that executes the given statement with the given params. + pub(crate) async fn execute_statement

( + &mut self, + statement: &Statement, + params: P, + ) -> Result> + where + P: Into, + { + let mut params = params.into(); + loop { + match params { + Params::Positional(params) => { + if statement.num_params() as usize != params.len() { + Err(DriverError::StmtParamsMismatch { + required: statement.num_params(), + supplied: params.len() as u16, + })? + } + + let params = params.into_iter().collect::>(); + + let (body, as_long_data) = + ComStmtExecuteRequestBuilder::new(statement.id()).build(&*params); + + if as_long_data { + self.send_long_data(statement.id(), params.iter()).await?; + } + + self.write_command_raw(body).await?; + break read_result_set(self).await; + } + Params::Named(_) => { + if statement.named_params.is_none() { + let error = DriverError::NamedParamsForPositionalQuery.into(); + return Err(error); + } + + params = match params.into_positional(statement.named_params.as_ref().unwrap()) + { + Ok(positional_params) => positional_params, + Err(error) => return Err(error.into()), + }; + + continue; + } + Params::Empty => { + if statement.num_params() > 0 { + let error = DriverError::StmtParamsMismatch { + required: statement.num_params(), + supplied: 0, + } + .into(); + return Err(error); + } + + let (body, _) = ComStmtExecuteRequestBuilder::new(statement.id()).build(&[]); + self.write_command_raw(body).await?; + break read_result_set(self).await; + } + } + } + } + + /// Helper, that sends all `Value::Bytes` in the given list of paramenters as long data. + async fn send_long_data<'a, I>(&mut self, statement_id: u32, params: I) -> Result<()> + where + I: Iterator, + { + for (i, value) in params.enumerate() { + if let Value::Bytes(bytes) = value { + let chunks = bytes.chunks(MAX_PAYLOAD_LEN - 6); + let chunks = chunks.chain(if bytes.is_empty() { + Some(&[][..]) + } else { + None + }); + for chunk in chunks { + let com = ComStmtSendLongData::new(statement_id, i, chunk); + self.write_command_raw(com.into()).await?; + } + } + } + + Ok(()) + } + + /// Helper, that closes statement with the given id. + pub(crate) async fn close_statement(&mut self, id: u32) -> Result<()> { + self.write_command_raw(ComStmtClose::new(id).into()).await + } +} diff --git a/src/queryable/transaction.rs b/src/queryable/transaction.rs index 1b0430aa..70f9f5a9 100644 --- a/src/queryable/transaction.rs +++ b/src/queryable/transaction.rs @@ -129,11 +129,11 @@ impl<'a, T: Queryable> Transaction<'a, T> { readonly, } = options; - if conn_like.get_tx_status() != TxStatus::None { + if conn_like.conn_ref().get_tx_status() != TxStatus::None { return Err(DriverError::NestedTransaction.into()); } - if readonly.is_some() && conn_like.get_server_version() < (5, 6, 5) { + if readonly.is_some() && conn_like.conn_ref().server_version() < (5, 6, 5) { return Err(DriverError::ReadOnlyTransNotSupported.into()); } @@ -158,7 +158,7 @@ impl<'a, T: Queryable> Transaction<'a, T> { conn_like.query_drop("START TRANSACTION").await? }; - conn_like.set_tx_status(TxStatus::InTransaction); + conn_like.conn_mut().set_tx_status(TxStatus::InTransaction); Ok(Transaction(conn_like)) } @@ -166,7 +166,7 @@ impl<'a, T: Queryable> Transaction<'a, T> { pub async fn commit(mut self) -> Result<()> { let result = self.0.query_iter("COMMIT").await?; result.drop_result().await?; - self.set_tx_status(TxStatus::None); + self.conn_mut().set_tx_status(TxStatus::None); Ok(()) } @@ -174,96 +174,26 @@ impl<'a, T: Queryable> Transaction<'a, T> { pub async fn rollback(mut self) -> Result<()> { let result = self.0.query_iter("ROLLBACK").await?; result.drop_result().await?; - self.set_tx_status(TxStatus::None); + self.conn_mut().set_tx_status(TxStatus::None); Ok(()) } } impl Drop for Transaction<'_, T> { fn drop(&mut self) { - if self.get_tx_status() == TxStatus::InTransaction { - self.set_tx_status(TxStatus::RequiresRollback); + let conn = self.conn_mut(); + if conn.get_tx_status() == TxStatus::InTransaction { + conn.set_tx_status(TxStatus::RequiresRollback); } } } impl<'a, T: ConnectionLike> ConnectionLike for Transaction<'a, T> { - fn conn_mut(&mut self) -> &mut crate::Conn { - self.0.conn_mut() + fn conn_ref(&self) -> &crate::Conn { + self.0.conn_ref() } - fn stream_mut(&mut self) -> &mut crate::io::Stream { - self.0.stream_mut() - } - fn get_affected_rows(&self) -> u64 { - self.0.get_affected_rows() - } - fn get_capabilities(&self) -> crate::consts::CapabilityFlags { - self.0.get_capabilities() - } - fn get_tx_status(&self) -> TxStatus { - self.0.get_tx_status() - } - fn get_last_insert_id(&self) -> Option { - self.0.get_last_insert_id() - } - fn get_info(&self) -> std::borrow::Cow<'_, str> { - self.0.get_info() - } - fn get_warnings(&self) -> u16 { - self.0.get_warnings() - } - fn get_local_infile_handler( - &self, - ) -> Option> { - self.0.get_local_infile_handler() - } - fn get_max_allowed_packet(&self) -> usize { - self.0.get_max_allowed_packet() - } - fn get_opts(&self) -> &crate::Opts { - self.0.get_opts() - } - fn get_pending_result(&self) -> Option<&crate::conn::PendingResult> { - self.0.get_pending_result() - } - fn get_server_version(&self) -> (u16, u16, u16) { - self.0.get_server_version() - } - fn get_status(&self) -> crate::consts::StatusFlags { - self.0.get_status() - } - fn set_last_ok_packet(&mut self, ok_packet: Option>) { - self.0.set_last_ok_packet(ok_packet) - } - fn set_tx_status(&mut self, tx_status: TxStatus) { - self.0.set_tx_status(tx_status) - } - fn set_pending_result(&mut self, meta: Option) { - self.0.set_pending_result(meta) - } - fn set_status(&mut self, status: crate::consts::StatusFlags) { - self.0.set_status(status) - } - fn reset_seq_id(&mut self) { - self.0.reset_seq_id() - } - fn sync_seq_id(&mut self) { - self.0.sync_seq_id() - } - fn touch(&mut self) -> () { - self.0.touch() - } - fn on_disconnect(&mut self) { - self.0.on_disconnect() - } - fn connection_id(&self) -> u32 { - self.0.connection_id() - } - fn stmt_cache_ref(&self) -> &crate::conn::stmt_cache::StmtCache { - self.0.stmt_cache_ref() - } - fn stmt_cache_mut(&mut self) -> &mut crate::conn::stmt_cache::StmtCache { - self.0.stmt_cache_mut() + fn conn_mut(&mut self) -> &mut crate::Conn { + self.0.conn_mut() } }