diff --git a/Cargo.toml b/Cargo.toml index a4345a72..3cbbb96c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,20 +7,23 @@ license = "MIT/Apache-2.0" name = "mysql_async" readme = "README.md" repository = "https://github.com/blackbeam/mysql_async" -version = "0.27.1" +version = "0.28.0" exclude = ["test/*"] edition = "2018" +categories = ["asynchronous", "database"] [dependencies] bytes = "1.0" +flate2 = { version = "1.0", default-features = false } futures-core = "0.3" futures-util = "0.3" futures-sink = "0.3" lazy_static = "1" lru = "0.6.0" mio = "0.7.7" -mysql_common = "0.26.0" +mysql_common = { version = "0.27.2", default-features = false } native-tls = "0.2" +once_cell = "1.7.2" pem = "0.8.1" percent-encoding = "2.1.0" pin-project = "1.0.2" @@ -36,11 +39,20 @@ uuid = { version = "0.8.1", features = ["v4"] } [dev-dependencies] tempfile = "3.1.0" -socket2 = "0.3.17" +socket2 = { version = "0.4.0", features = ["all"] } tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread"] } rand = "0.8.0" [features] +default = [ + "flate2/zlib", + "mysql_common/bigdecimal", + "mysql_common/chrono", + "mysql_common/rust_decimal", + "mysql_common/time", + "mysql_common/uuid", + "mysql_common/frunk", +] nightly = [] [lib] diff --git a/azure-pipelines.yml b/azure-pipelines.yml index c4ac5387..62f453ed 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -21,6 +21,12 @@ jobs: sudo apt-get -y install mysql-server libmysqlclient-dev curl sudo service mysql start mysql -e "SET GLOBAL max_allowed_packet = 36700160;" -uroot -proot + mysql -e "SET @@GLOBAL.ENFORCE_GTID_CONSISTENCY = WARN;" -uroot -proot + mysql -e "SET @@GLOBAL.ENFORCE_GTID_CONSISTENCY = ON;" -uroot -proot + mysql -e "SET @@GLOBAL.GTID_MODE = OFF_PERMISSIVE;" -uroot -proot + mysql -e "SET @@GLOBAL.GTID_MODE = ON_PERMISSIVE;" -uroot -proot + mysql -e "SET @@GLOBAL.GTID_MODE = ON;" -uroot -proot + mysql -e "PURGE BINARY LOGS BEFORE now();" -uroot -proot displayName: Install MySql - bash: | curl https://sh.rustup.rs -sSf | sh -s -- -y --default-toolchain $(RUST_TOOLCHAIN) @@ -43,40 +49,6 @@ jobs: DATABASE_URL: mysql://root:root@127.0.0.1:3306/mysql displayName: Run tests - # - job: "TestBasicMacOs" - # pool: - # vmImage: "macOS-10.15" - # strategy: - # maxParallel: 10 - # matrix: - # stable: - # RUST_TOOLCHAIN: stable - # steps: - # - bash: | - # brew update - # brew install mysql - # brew services start mysql - # brew services stop mysql - # sleep 3 - # echo 'local_infile=1' >> /usr/local/etc/my.cnf - # echo 'socket=/tmp/mysql.sock' >> /usr/local/etc/my.cnf - # brew services start mysql - # sleep 5 - # /usr/local/Cellar/mysql/*/bin/mysql -e "SET GLOBAL max_allowed_packet = 36700160;" -uroot - # displayName: Install MySql - # - bash: | - # curl https://sh.rustup.rs -sSf | sh -s -- -y --default-toolchain $RUST_TOOLCHAIN - # displayName: Install rust (MacOs) - # - bash: | - # SSL=false COMPRESS=false cargo test - # SSL=true COMPRESS=false cargo test - # SSL=false COMPRESS=true cargo test - # SSL=true COMPRESS=true cargo test - # env: - # RUST_BACKTRACE: 1 - # DATABASE_URL: mysql://root@127.0.0.1/mysql - # displayName: Run tests - - job: "TestBasicWindows" pool: vmImage: "vs2017-win2016" @@ -95,6 +67,11 @@ jobs: call "C:\Program Files (x86)\MySQL\MySQL Installer for Windows\MySQLInstallerConsole.exe" community install server;8.0.11;x64:*:port=3306;rootpasswd=password;servicename=MySQL -silent netsh advfirewall firewall add rule name="Allow mysql" dir=in action=allow edge=yes remoteip=any protocol=TCP localport=80,8080,3306 "C:\Program Files\MySQL\MySQL Server 8.0\bin\mysql" -e "SET GLOBAL max_allowed_packet = 36700160;" -uroot -ppassword + "C:\Program Files\MySQL\MySQL Server 8.0\bin\mysql" -e "SET @@GLOBAL.ENFORCE_GTID_CONSISTENCY = WARN;" -uroot -ppassword + "C:\Program Files\MySQL\MySQL Server 8.0\bin\mysql" -e "SET @@GLOBAL.ENFORCE_GTID_CONSISTENCY = ON;" -uroot -ppassword + "C:\Program Files\MySQL\MySQL Server 8.0\bin\mysql" -e "SET @@GLOBAL.GTID_MODE = OFF_PERMISSIVE;" -uroot -ppassword + "C:\Program Files\MySQL\MySQL Server 8.0\bin\mysql" -e "SET @@GLOBAL.GTID_MODE = ON_PERMISSIVE;" -uroot -ppassword + "C:\Program Files\MySQL\MySQL Server 8.0\bin\mysql" -e "SET @@GLOBAL.GTID_MODE = ON;" -uroot -ppassword displayName: Install MySql - bash: | rustup install $RUST_TOOLCHAIN @@ -130,16 +107,20 @@ jobs: docker --version displayName: Install docker - bash: | - docker run --rm --name container -v `pwd`:/root -p 3307:3306 -d -e MYSQL_ROOT_PASSWORD=password mysql:$(DB_VERSION) --max-allowed-packet=36700160 --local-infile + if [[ "5.6" == "$(DB_VERSION)" ]]; then ARG="--secure-auth=OFF"; fi + docker run -d --name container -v `pwd`:/root -p 3307:3306 -e MYSQL_ROOT_PASSWORD=password mysql:$(DB_VERSION) --max-allowed-packet=36700160 --local-infile --log-bin=mysql-bin --log-slave-updates --gtid_mode=ON --enforce_gtid_consistency=ON --server-id=1 $ARG while ! nc -W 1 localhost 3307 | grep -q -P '.+'; do sleep 1; done displayName: Run MySql in Docker + - bash: | + docker exec container bash -l -c "mysql -uroot -ppassword -e \"SET old_passwords = 1; GRANT ALL PRIVILEGES ON *.* TO 'root2'@'%' IDENTIFIED WITH mysql_old_password AS 'password'; SET PASSWORD FOR 'root2'@'%' = OLD_PASSWORD('password')\""; + condition: eq(variables['DB_VERSION'], '5.6') - bash: | docker exec container bash -l -c "apt-get update" docker exec container bash -l -c "apt-get install -y curl clang libssl-dev pkg-config" docker exec container bash -l -c "curl https://sh.rustup.rs -sSf | sh -s -- -y --default-toolchain stable" displayName: Install Rust in docker - bash: | - if [[ "5.6" != "$(DB_VERSION)" ]]; then SSL=true; fi + if [[ "5.6" != "$(DB_VERSION)" ]]; then SSL=true; else DATABASE_URL="mysql://root2:password@127.0.0.1/mysql?secure_auth=false"; fi docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL cargo test" docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL COMPRESS=true cargo test" docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL SSL=$SSL cargo test" @@ -186,10 +167,11 @@ jobs: --max-allowed-packet=36700160 \ --local-infile \ --performance-schema=on \ + --log-bin=mysql-bin --gtid-domain-id=1 --server-id=1 \ --ssl \ --ssl-ca=/root/rust-mysql-simple/tests/ca-cert.pem \ --ssl-cert=/root/rust-mysql-simple/tests/server-cert.pem \ - --ssl-key=/root/rust-mysql-simple/tests/server-key.pem + --ssl-key=/root/rust-mysql-simple/tests/server-key.pem & while ! nc -W 1 localhost 3307 | grep -q -P '.+'; do sleep 1; done displayName: Run MariaDb in Docker - bash: | diff --git a/src/buffer_pool.rs b/src/buffer_pool.rs new file mode 100644 index 00000000..14116688 --- /dev/null +++ b/src/buffer_pool.rs @@ -0,0 +1,103 @@ +// Copyright (c) 2021 Anatoly Ikorsky +// +// Licensed under the Apache License, Version 2.0 +// or the MIT +// license , at your +// option. All files in the project carrying such notice may not be copied, +// modified, or distributed except according to those terms. + +use std::{ + mem::replace, + ops::Deref, + sync::{Arc, Mutex}, +}; + +#[derive(Debug)] +pub struct BufferPool { + pool_cap: usize, + buffer_cap: usize, + pool: Mutex>>, +} + +impl BufferPool { + pub fn new() -> Self { + let pool_cap = std::env::var("MYSQL_ASYNC_BUFFER_POOL_CAP") + .ok() + .and_then(|x| x.parse().ok()) + .unwrap_or(128_usize); + + let buffer_cap = std::env::var("MYSQL_ASYNC_BUFFER_SIZE_CAP") + .ok() + .and_then(|x| x.parse().ok()) + .unwrap_or(4 * 1024 * 1024); + + Self { + pool: Default::default(), + pool_cap, + buffer_cap, + } + } + + pub fn get(self: &Arc) -> PooledBuf { + let mut buf = self.pool.lock().unwrap().pop().unwrap_or_default(); + + // SAFETY: + // 1. OK – 0 is always within capacity + // 2. OK - nothing to initialize + unsafe { buf.set_len(0) } + + PooledBuf(buf, self.clone()) + } + + pub fn get_with>(self: &Arc, content: T) -> PooledBuf { + let mut buf = self.get(); + buf.as_mut().extend_from_slice(content.as_ref()); + buf + } + + fn put(self: &Arc, mut buf: Vec) { + if buf.len() > self.buffer_cap { + // TODO: until `Vec::shrink_to` stabilization + + // SAFETY: + // 1. OK – new_len <= capacity + // 2. OK - 0..new_len is initialized + unsafe { buf.set_len(self.buffer_cap) } + buf.shrink_to_fit(); + } + + let mut pool = self.pool.lock().unwrap(); + if pool.len() < self.pool_cap { + pool.push(buf); + } + } +} + +impl Default for BufferPool { + fn default() -> Self { + Self::new() + } +} + +#[derive(Debug)] +pub struct PooledBuf(Vec, Arc); + +impl AsMut> for PooledBuf { + fn as_mut(&mut self) -> &mut Vec { + &mut self.0 + } +} + +impl Deref for PooledBuf { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + self.0.deref() + } +} + +impl Drop for PooledBuf { + fn drop(&mut self) { + self.1.put(replace(&mut self.0, vec![])) + } +} diff --git a/src/conn/binlog_stream.rs b/src/conn/binlog_stream.rs new file mode 100644 index 00000000..f6ffab13 --- /dev/null +++ b/src/conn/binlog_stream.rs @@ -0,0 +1,96 @@ +// Copyright (c) 2020 Anatoly Ikorsky +// +// Licensed under the Apache License, Version 2.0 +// or the MIT +// license , at your +// option. All files in the project carrying such notice may not be copied, +// modified, or distributed except according to those terms. + +use futures_core::ready; +use mysql_common::{ + binlog::{ + consts::BinlogVersion::Version4, + events::{Event, TableMapEvent}, + EventStreamReader, + }, + io::ParseBuf, + packets::{ErrPacket, NetworkStreamTerminator, OkPacketDeserializer}, +}; + +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use crate::{error::DriverError, io::ReadPacket, Conn, Result}; + +/// Binlog event stream. +/// +/// Stream initialization is lazy, i.e. binlog won't be requested until this stream is polled. +pub struct BinlogStream { + read_packet: ReadPacket<'static, 'static>, + esr: EventStreamReader, +} + +impl BinlogStream { + /// `conn` is a `Conn` with `request_binlog` executed on it. + pub(super) fn new(conn: Conn) -> Self { + BinlogStream { + read_packet: ReadPacket::new(conn), + esr: EventStreamReader::new(Version4), + } + } + + /// Returns a table map event for the given table id. + pub fn get_tme(&self, table_id: u64) -> Option<&TableMapEvent<'static>> { + self.esr.get_tme(table_id) + } +} + +impl futures_core::stream::Stream for BinlogStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let packet = match ready!(Pin::new(&mut self.read_packet).poll(cx)) { + Ok(packet) => packet, + Err(err) => return Poll::Ready(Some(Err(err.into()))), + }; + + let first_byte = packet.get(0).copied(); + + if first_byte == Some(255) { + if let Ok(ErrPacket::Error(err)) = + ParseBuf(&*packet).parse(self.read_packet.conn_ref().capabilities()) + { + return Poll::Ready(Some(Err(From::from(err)))); + } + } + + if first_byte == Some(254) && packet.len() < 8 { + if ParseBuf(&*packet) + .parse::>( + self.read_packet.conn_ref().capabilities(), + ) + .is_ok() + { + return Poll::Ready(None); + } + } + + if first_byte == Some(0) { + let event_data = &packet[1..]; + match self.esr.read(event_data) { + Ok(event) => { + return Poll::Ready(Some(Ok(event))); + } + Err(err) => return Poll::Ready(Some(Err(err.into()))), + } + } else { + return Poll::Ready(Some(Err(DriverError::UnexpectedPacket { + payload: packet.to_vec(), + } + .into()))); + } + } +} diff --git a/src/conn/mod.rs b/src/conn/mod.rs index e39f4c03..9488fd5a 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -6,16 +6,19 @@ // option. All files in the project carrying such notice may not be copied, // modified, or distributed except according to those terms. +use futures_util::FutureExt; pub use mysql_common::named_params; use mysql_common::{ - constants::DEFAULT_MAX_ALLOWED_PACKET, + constants::{DEFAULT_MAX_ALLOWED_PACKET, UTF8_GENERAL_CI}, crypto, + io::ParseBuf, packets::{ - parse_auth_switch_request, parse_err_packet, parse_handshake_packet, parse_ok_packet, - AuthPlugin, AuthSwitchRequest, ErrPacket, HandshakeResponse, OkPacket, OkPacketKind, - SslRequest, + binlog_request::BinlogRequest, AuthPlugin, AuthSwitchRequest, CommonOkPacket, ErrPacket, + HandshakePacket, HandshakeResponse, OkPacket, OkPacketDeserializer, OldAuthSwitchRequest, + ResultSetTerminator, SslRequest, }, + proto::MySerialize, }; use std::{ @@ -29,6 +32,7 @@ use std::{ }; use crate::{ + buffer_pool::PooledBuf, conn::{pool::Pool, stmt_cache::StmtCache}, consts::{CapabilityFlags, Command, StatusFlags}, error::*, @@ -39,11 +43,12 @@ use crate::{ transaction::TxStatus, BinaryProtocol, Queryable, TextProtocol, }, - OptsBuilder, + BinlogStream, OptsBuilder, }; use self::routines::Routine; +pub mod binlog_stream; pub mod pool; pub mod routines; pub mod stmt_cache; @@ -77,12 +82,13 @@ fn disconnect(mut conn: Conn) { struct ConnInner { stream: Option, id: u32, + is_mariadb: bool, version: (u16, u16, u16), socket: Option, capabilities: CapabilityFlags, status: StatusFlags, last_ok_packet: Option>, - last_err_packet: Option>, + last_err_packet: Option>, pool: Option, pending_result: Option, tx_status: TxStatus, @@ -94,7 +100,7 @@ struct ConnInner { auth_plugin: AuthPlugin<'static>, auth_switched: bool, /// Connection is already disconnected. - disconnected: bool, + pub(crate) disconnected: bool, } impl fmt::Debug for ConnInner { @@ -120,6 +126,7 @@ impl ConnInner { last_ok_packet: None, last_err_packet: None, stream: None, + is_mariadb: false, version: (0, 0, 0), id: 0, pending_result: None, @@ -197,6 +204,11 @@ impl Conn { .unwrap_or_default() } + /// Returns a reference to the last OK packet. + pub fn last_ok_packet(&self) -> Option<&OkPacket<'static>> { + self.inner.last_ok_packet.as_ref() + } + pub(crate) fn stream_mut(&mut self) -> Result<&mut Stream> { self.inner.stream_mut() } @@ -232,10 +244,16 @@ impl Conn { } /// Handles ERR packet. - pub(crate) fn handle_err(&mut self, err_packet: ErrPacket<'static>) { - self.inner.status = StatusFlags::empty(); - self.inner.last_ok_packet = None; - self.inner.last_err_packet = Some(err_packet); + pub(crate) fn handle_err(&mut self, err_packet: ErrPacket<'_>) -> Result<()> { + match err_packet { + ErrPacket::Error(err) => { + self.inner.status = StatusFlags::empty(); + self.inner.last_ok_packet = None; + self.inner.last_err_packet = Some(err.clone().into_owned()); + Err(Error::from(err)) + } + ErrPacket::Progress(_) => Ok(()), + } } /// Returns the current transaction status. @@ -353,7 +371,7 @@ impl Conn { async fn handle_handshake(&mut self) -> Result<()> { let packet = self.read_packet().await?; - let handshake = parse_handshake_packet(&*packet)?; + let handshake = ParseBuf(&*packet).parse::(())?; self.inner.nonce = { let mut nonce = Vec::from(handshake.scramble_1_ref()); nonce.extend_from_slice(handshake.scramble_2_ref().unwrap_or(&[][..])); @@ -363,12 +381,18 @@ impl Conn { self.inner.capabilities = handshake.capabilities() & self.inner.opts.get_capabilities(); self.inner.version = handshake .maria_db_server_version_parsed() + .map(|version| { + self.inner.is_mariadb = true; + version + }) .or_else(|| handshake.server_version_parsed()) .unwrap_or((0, 0, 0)); self.inner.id = handshake.connection_id(); self.inner.status = handshake.status_flags(); self.inner.auth_plugin = match handshake.auth_plugin() { - Some(AuthPlugin::MysqlNativePassword) => AuthPlugin::MysqlNativePassword, + Some(AuthPlugin::MysqlNativePassword | AuthPlugin::MysqlOldPassword) => { + AuthPlugin::MysqlNativePassword + } Some(AuthPlugin::CachingSha2Password) => AuthPlugin::CachingSha2Password, Some(AuthPlugin::Other(ref name)) => { let name = String::from_utf8_lossy(name).into(); @@ -386,8 +410,12 @@ impl Conn { .get_capabilities() .contains(CapabilityFlags::CLIENT_SSL) { - let ssl_request = SslRequest::new(self.inner.capabilities); - self.write_packet(ssl_request.as_ref()).await?; + let ssl_request = SslRequest::new( + self.inner.capabilities, + DEFAULT_MAX_ALLOWED_PACKET as u32, + UTF8_GENERAL_CI as u8, + ); + self.write_struct(&ssl_request).await?; let conn = self; let ssl_opts = conn.opts().ssl_opts().cloned().expect("unreachable"); let domain = conn.opts().ip_or_hostname().into(); @@ -405,16 +433,20 @@ impl Conn { .gen_data(self.inner.opts.pass(), &*self.inner.nonce); let handshake_response = HandshakeResponse::new( - &auth_data, + auth_data.as_deref(), self.inner.version, - self.inner.opts.user(), - self.inner.opts.db_name(), - &self.inner.auth_plugin, + self.inner.opts.user().map(|x| x.as_bytes()), + self.inner.opts.db_name().map(|x| x.as_bytes()), + Some(self.inner.auth_plugin.borrow()), self.capabilities(), - &Default::default(), // TODO: Add support + Default::default(), // TODO: Add support ); - self.write_packet(handshake_response.as_ref()).await?; + // Serialize here to satisfy borrow checker. + let mut buf = crate::BUFFER_POOL.get(); + handshake_response.serialize(buf.as_mut()); + + self.write_packet(buf).await?; Ok(()) } @@ -424,15 +456,31 @@ impl Conn { ) -> Result<()> { if !self.inner.auth_switched { self.inner.auth_switched = true; - self.inner.nonce = auth_switch_request.plugin_data().into(); + + if matches!( + auth_switch_request.auth_plugin(), + AuthPlugin::MysqlOldPassword + ) { + if self.inner.opts.secure_auth() { + return Err(DriverError::MysqlOldPasswordDisabled.into()); + } + } + self.inner.auth_plugin = auth_switch_request.auth_plugin().clone().into_owned(); + let plugin_data = self .inner .auth_plugin - .gen_data(self.inner.opts.pass(), &*self.inner.nonce) - .unwrap_or_else(Vec::new); - self.write_packet(plugin_data).await?; + .gen_data(self.inner.opts.pass(), &*self.inner.nonce); + + if let Some(plugin_data) = plugin_data { + self.write_struct(&plugin_data).await?; + } else { + self.write_packet(crate::BUFFER_POOL.get()).await?; + } + self.continue_auth().await?; + Ok(()) } else { unreachable!("auth_switched flag should be checked by caller") @@ -444,7 +492,7 @@ impl Conn { // see https://github.com/rust-lang/rust/issues/46415#issuecomment-528099782 Box::pin(async move { match self.inner.auth_plugin { - AuthPlugin::MysqlNativePassword => { + AuthPlugin::MysqlNativePassword | AuthPlugin::MysqlOldPassword => { self.continue_mysql_native_password_auth().await?; Ok(()) } @@ -487,32 +535,39 @@ impl Conn { self.drop_packet().await } Some(0x04) => { - let mut pass = self.inner.opts.pass().map(Vec::from).unwrap_or_default(); - pass.push(0); + let pass = self.inner.opts.pass().unwrap_or_default(); + let mut pass = crate::BUFFER_POOL.get_with(pass.as_bytes()); + pass.as_mut().push(0); if self.is_secure() { - self.write_packet(&*pass).await?; + self.write_packet(pass).await?; } else { - self.write_packet(&[0x02][..]).await?; + self.write_bytes(&[0x02][..]).await?; let packet = self.read_packet().await?; let key = &packet[1..]; - for (i, byte) in pass.iter_mut().enumerate() { + for (i, byte) in pass.as_mut().iter_mut().enumerate() { *byte ^= self.inner.nonce[i % self.inner.nonce.len()]; } let encrypted_pass = crypto::encrypt(&*pass, key); - self.write_packet(&*encrypted_pass).await?; + self.write_bytes(&*encrypted_pass).await?; }; self.drop_packet().await?; Ok(()) } - _ => Err(DriverError::UnexpectedPacket { payload: packet }.into()), + _ => Err(DriverError::UnexpectedPacket { + payload: packet.to_vec(), + } + .into()), }, Some(0xfe) if !self.inner.auth_switched => { - let auth_switch_request = parse_auth_switch_request(&*packet)?.into_owned(); + let auth_switch_request = ParseBuf(&*packet).parse::(())?; self.perform_auth_switch(auth_switch_request).await?; Ok(()) } - _ => Err(DriverError::UnexpectedPacket { payload: packet }.into()), + _ => Err(DriverError::UnexpectedPacket { + payload: packet.to_vec(), + } + .into()), } } @@ -521,45 +576,70 @@ impl Conn { match packet.get(0) { Some(0x00) => Ok(()), Some(0xfe) if !self.inner.auth_switched => { - let auth_switch_request = parse_auth_switch_request(packet.as_ref())?.into_owned(); - self.perform_auth_switch(auth_switch_request).await?; - Ok(()) + let auth_switch = if packet.len() > 1 { + ParseBuf(&*packet).parse(())? + } else { + let _ = ParseBuf(&*packet).parse::(())?; + // map OldAuthSwitch to AuthSwitch with mysql_old_password plugin + AuthSwitchRequest::new( + "mysql_old_password".as_bytes(), + self.inner.nonce.clone(), + ) + }; + self.perform_auth_switch(auth_switch).await } - _ => Err(DriverError::UnexpectedPacket { payload: packet }.into()), + _ => Err(DriverError::UnexpectedPacket { + payload: packet.to_vec(), + } + .into()), } } - fn handle_packet(&mut self, packet: &[u8]) -> Result<()> { - let kind = if self.get_pending_result().is_some() { - OkPacketKind::ResultSetTerminator + /// Returns `true` for ProgressReport packet. + fn handle_packet(&mut self, packet: &PooledBuf) -> Result { + let ok_packet = if self.get_pending_result().is_some() { + ParseBuf(&*packet) + .parse::>(self.capabilities()) + .map(|x| x.into_inner()) } else { - OkPacketKind::Other + ParseBuf(&*packet) + .parse::>(self.capabilities()) + .map(|x| x.into_inner()) }; - if let Ok(ok_packet) = parse_ok_packet(&*packet, self.capabilities(), kind) { + if let Ok(ok_packet) = ok_packet { self.handle_ok(ok_packet.into_owned()); - } else if let Ok(err_packet) = parse_err_packet(&*packet, self.capabilities()) { - self.handle_err(err_packet.clone().into_owned()); - return Err(err_packet.into()); + } else { + let err_packet = ParseBuf(&*packet).parse::(self.capabilities()); + if let Ok(err_packet) = err_packet { + self.handle_err(err_packet)?; + return Ok(true); + } } - Ok(()) + Ok(false) } - pub(crate) async fn read_packet(&mut self) -> Result> { - let packet = crate::io::ReadPacket::new(&mut *self) - .await - .map_err(|io_err| { - self.inner.stream.take(); - self.inner.disconnected = true; - Error::from(io_err) - })?; - self.handle_packet(&*packet)?; - Ok(packet) + pub(crate) async fn read_packet(&mut self) -> Result { + loop { + let packet = crate::io::ReadPacket::new(&mut *self) + .await + .map_err(|io_err| { + self.inner.stream.take(); + self.inner.disconnected = true; + Error::from(io_err) + })?; + if self.handle_packet(&packet)? { + // ignore progress report + continue; + } else { + return Ok(packet); + } + } } /// Returns future that reads packets from a server. - pub(crate) async fn read_packets(&mut self, n: usize) -> Result>> { + pub(crate) async fn read_packets(&mut self, n: usize) -> Result> { let mut packets = Vec::with_capacity(n); for _ in 0..n { packets.push(self.read_packet().await?); @@ -567,11 +647,8 @@ impl Conn { Ok(packets) } - pub(crate) async fn write_packet(&mut self, data: T) -> Result<()> - where - T: Into>, - { - crate::io::WritePacket::new(&mut *self, data.into()) + pub(crate) async fn write_packet(&mut self, data: PooledBuf) -> Result<()> { + crate::io::WritePacket::new(&mut *self, data) .await .map_err(|io_err| { self.inner.stream.take(); @@ -580,8 +657,28 @@ impl Conn { }) } + /// Writes bytes to a server. + pub(crate) async fn write_bytes(&mut self, bytes: &[u8]) -> Result<()> { + let buf = crate::BUFFER_POOL.get_with(bytes); + self.write_packet(buf).await + } + + /// Sends a serializable structure to a server. + pub(crate) async fn write_struct(&mut self, x: &T) -> Result<()> { + let mut buf = crate::BUFFER_POOL.get(); + x.serialize(buf.as_mut()); + self.write_packet(buf).await + } + + /// Sends a command to a server. + pub(crate) async fn write_command(&mut self, cmd: &T) -> Result<()> { + self.clean_dirty().await?; + self.reset_seq_id(); + self.write_struct(cmd).await + } + /// Returns future that sends full command body to a server. - pub(crate) async fn write_command_raw(&mut self, body: Vec) -> Result<()> { + pub(crate) async fn write_command_raw(&mut self, body: PooledBuf) -> Result<()> { debug_assert!(!body.is_empty()); self.clean_dirty().await?; self.reset_seq_id(); @@ -594,10 +691,11 @@ impl Conn { T: AsRef<[u8]>, { let cmd_data = cmd_data.as_ref(); - let mut body = Vec::with_capacity(1 + cmd_data.len()); + let mut buf = crate::BUFFER_POOL.get(); + let body = buf.as_mut(); body.push(cmd as u8); body.extend_from_slice(cmd_data); - self.write_command_raw(body).await + self.write_command_raw(buf).await } async fn drop_packet(&mut self) -> Result<()> { @@ -618,7 +716,7 @@ impl Conn { /// Returns a future that resolves to [`Conn`]. pub fn new>(opts: T) -> crate::BoxFuture<'static, Conn> { let opts = opts.into(); - let fut = Box::pin(async move { + async move { let mut conn = Conn::empty(opts.clone()); let stream = if let Some(_path) = opts.socket() { @@ -649,8 +747,8 @@ impl Conn { conn.run_init_commands().await?; Ok(conn) - }); - crate::BoxFuture(fut) + } + .boxed() } /// Returns a future that resolves to [`Conn`]. @@ -689,18 +787,25 @@ impl Conn { /// Reads and stores `max_allowed_packet` in the connection. async fn read_max_allowed_packet(&mut self) -> Result<()> { - let row_opt = self.query_first("SELECT @@max_allowed_packet").await?; + let max_allowed_packet = if let Some(value) = self.opts().max_allowed_packet() { + Some(value) + } else { + self.query_first("SELECT @@max_allowed_packet").await? + }; if let Some(stream) = self.inner.stream.as_mut() { - stream.set_max_allowed_packet(row_opt.unwrap_or((DEFAULT_MAX_ALLOWED_PACKET,)).0); + stream.set_max_allowed_packet(max_allowed_packet.unwrap_or(DEFAULT_MAX_ALLOWED_PACKET)); } Ok(()) } /// Reads and stores `wait_timeout` in the connection. async fn read_wait_timeout(&mut self) -> Result<()> { - let row_opt = self.query_first("SELECT @@wait_timeout").await?; - let wait_timeout_secs = row_opt.unwrap_or((28800,)).0; - self.inner.wait_timeout = Duration::from_secs(wait_timeout_secs); + let wait_timeout = if let Some(value) = self.opts().wait_timeout() { + Some(value) + } else { + self.query_first("SELECT @@wait_timeout").await? + }; + self.inner.wait_timeout = Duration::from_secs(wait_timeout.unwrap_or(28800) as u64); Ok(()) } @@ -726,7 +831,14 @@ impl Conn { pub async fn reset(&mut self) -> Result<()> { let pool = self.inner.pool.clone(); - if self.inner.version > (5, 7, 2) { + let supports_com_reset_connection = if self.inner.is_mariadb { + self.inner.version >= (10, 2, 4) + } else { + // assuming mysql + self.inner.version > (5, 7, 2) + }; + + if supports_com_reset_connection { self.routine(routines::ResetRoutine).await?; } else { let opts = self.inner.opts.clone(); @@ -802,15 +914,184 @@ impl Conn { } Ok(self) } + + async fn register_as_slave(&mut self, server_id: u32) -> Result<()> { + use mysql_common::packets::ComRegisterSlave; + + self.query_drop("SET @master_binlog_checksum='ALL'").await?; + self.write_command(&ComRegisterSlave::new(server_id)) + .await?; + + // Server will respond with OK. + self.read_packet().await?; + + Ok(()) + } + + async fn request_binlog(&mut self, request: BinlogRequest<'_>) -> Result<()> { + self.register_as_slave(request.server_id()).await?; + self.write_command(&request.as_cmd()).await?; + Ok(()) + } + + pub async fn get_binlog_stream(mut self, request: BinlogRequest<'_>) -> Result { + // We'll disconnect this connection from a pool before requesting the binlog. + self.inner.pool = None; + self.request_binlog(request).await?; + + Ok(BinlogStream::new(self)) + } } #[cfg(test)] mod test { + use futures_util::stream::StreamExt; + use mysql_common::binlog::events::EventData; + use tokio::time::timeout; + + use std::time::Duration; + use crate::{ - from_row, params, prelude::*, test_misc::get_opts, Conn, Error, OptsBuilder, Pool, TxOpts, - WhiteListFsLocalInfileHandler, + from_row, params, prelude::*, test_misc::get_opts, BinlogDumpFlags, BinlogRequest, Conn, + Error, OptsBuilder, Pool, TxOpts, WhiteListFsLocalInfileHandler, }; + async fn gen_dummy_data() -> super::Result<()> { + let mut conn = Conn::new(get_opts()).await?; + + "CREATE TABLE IF NOT EXISTS customers (customer_id int not null)" + .ignore(&mut conn) + .await?; + + for i in 0_u8..100 { + "INSERT INTO customers(customer_id) VALUES (?)" + .with((i,)) + .ignore(&mut conn) + .await?; + } + + "DROP TABLE customers".ignore(&mut conn).await?; + + Ok(()) + } + + #[tokio::test] + async fn should_read_binlog() -> super::Result<()> { + async fn get_conn() -> super::Result<(Conn, Vec, u64)> { + let mut conn = Conn::new(get_opts()).await?; + + if let Ok(Some(gtid_mode)) = "SELECT @@GLOBAL.GTID_MODE" + .first::(&mut conn) + .await + { + if !gtid_mode.starts_with("ON") { + panic!( + "GTID_MODE is disabled \ + (enable using --gtid_mode=ON --enforce_gtid_consistency=ON)" + ); + } + } + + let row: crate::Row = "SHOW BINARY LOGS".first(&mut conn).await?.unwrap(); + let filename = row.get(0).unwrap(); + let position = row.get(1).unwrap(); + + gen_dummy_data().await.unwrap(); + Ok((conn, filename, position)) + } + + // iterate using COM_BINLOG_DUMP + let (conn, filename, pos) = get_conn().await.unwrap(); + let is_mariadb = conn.inner.is_mariadb; + + let mut binlog_stream = conn + .get_binlog_stream(BinlogRequest::new(12).with_filename(filename).with_pos(pos)) + .await + .unwrap(); + + let mut events_num = 0; + while let Ok(Some(event)) = timeout(Duration::from_secs(1), binlog_stream.next()).await { + let event = event.unwrap(); + events_num += 1; + + // assert that event type is known + event.header().event_type().unwrap(); + + // iterate over rows of an event + match event.read_data()?.unwrap() { + EventData::RowsEvent(re) => { + let tme = binlog_stream.get_tme(re.table_id()); + for row in re.rows(tme.unwrap()) { + row.unwrap(); + } + } + _ => (), + } + } + assert!(events_num > 0); + + if !is_mariadb { + // iterate using COM_BINLOG_DUMP_GTID + let (conn, filename, pos) = get_conn().await.unwrap(); + + let mut binlog_stream = conn + .get_binlog_stream( + BinlogRequest::new(13) + .with_use_gtid(true) + .with_filename(filename) + .with_pos(pos), + ) + .await + .unwrap(); + + events_num = 0; + while let Ok(Some(event)) = timeout(Duration::from_secs(1), binlog_stream.next()).await + { + let event = event.unwrap(); + events_num += 1; + + // assert that event type is known + event.header().event_type().unwrap(); + + // iterate over rows of an event + match event.read_data()?.unwrap() { + EventData::RowsEvent(re) => { + let tme = binlog_stream.get_tme(re.table_id()); + for row in re.rows(tme.unwrap()) { + row.unwrap(); + } + } + _ => (), + } + } + assert!(events_num > 0); + } + + // iterate using COM_BINLOG_DUMP with BINLOG_DUMP_NON_BLOCK flag + let (conn, filename, pos) = get_conn().await.unwrap(); + + let mut binlog_stream = conn + .get_binlog_stream( + BinlogRequest::new(14) + .with_filename(filename) + .with_pos(pos) + .with_flags(BinlogDumpFlags::BINLOG_DUMP_NON_BLOCK), + ) + .await + .unwrap(); + + events_num = 0; + while let Some(event) = binlog_stream.next().await { + let event = event.unwrap(); + events_num += 1; + event.header().event_type().unwrap(); + event.read_data()?; + } + assert!(events_num > 0); + + Ok(()) + } + #[test] fn opts_should_satisfy_send_and_sync() { struct A(T); diff --git a/src/conn/pool/futures/get_conn.rs b/src/conn/pool/futures/get_conn.rs index b7e1cf0b..429a016a 100644 --- a/src/conn/pool/futures/get_conn.rs +++ b/src/conn/pool/futures/get_conn.rs @@ -6,28 +6,45 @@ // option. All files in the project carrying such notice may not be copied, // modified, or distributed except according to those terms. -use futures_core::ready; use std::{ + fmt, future::Future, pin::Pin, task::{Context, Poll}, }; +use futures_core::ready; + use crate::{ conn::{pool::Pool, Conn}, error::*, - BoxFuture, }; /// States of the GetConn future. -#[derive(Debug)] pub(crate) enum GetConnInner { New, Done, // TODO: one day this should be an existential - Connecting(BoxFuture<'static, Conn>), + Connecting(crate::BoxFuture<'static, Conn>), /// This future will check, that idling connection is alive. - Checking(BoxFuture<'static, Conn>), + Checking(crate::BoxFuture<'static, Conn>), +} + +impl fmt::Debug for GetConnInner { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + GetConnInner::New => f.debug_tuple("GetConnInner::New").finish(), + GetConnInner::Done => f.debug_tuple("GetConnInner::Done").finish(), + GetConnInner::Connecting(_) => f + .debug_tuple("GetConnInner::Connecting") + .field(&"") + .finish(), + GetConnInner::Checking(_) => f + .debug_tuple("GetConnInner::Checking") + .field(&"") + .finish(), + } + } } impl GetConnInner { diff --git a/src/conn/pool/mod.rs b/src/conn/pool/mod.rs index 93174730..74c94e8f 100644 --- a/src/conn/pool/mod.rs +++ b/src/conn/pool/mod.rs @@ -6,10 +6,12 @@ // option. All files in the project carrying such notice may not be copied, // modified, or distributed except according to those terms. +use futures_util::FutureExt; use tokio::sync::mpsc; use std::{ collections::VecDeque, + convert::TryFrom, pin::Pin, str::FromStr, sync::{atomic, Arc, Mutex}, @@ -22,7 +24,6 @@ use crate::{ error::*, opts::{Opts, PoolOpts}, queryable::transaction::{Transaction, TxOpts, TxStatus}, - BoxFuture, }; mod recycler; @@ -111,8 +112,16 @@ pub struct Pool { impl Pool { /// Creates a new pool of connections. - pub fn new>(opts: O) -> Pool { - let opts = opts.into(); + /// + /// # Panic + /// + /// It'll panic if `Opts::try_from(opts)` returns error. + pub fn new(opts: O) -> Pool + where + Opts: TryFrom, + >::Error: std::error::Error, + { + let opts = Opts::try_from(opts).unwrap(); let pool_opts = opts.pool_opts().clone(); let (tx, rx) = mpsc::unbounded_channel(); Pool { @@ -242,10 +251,13 @@ impl Pool { if !conn.expired() { return Poll::Ready(Ok(GetConn { pool: Some(self.clone()), - inner: GetConnInner::Checking(BoxFuture(Box::pin(async move { - conn.stream_mut()?.check().await?; - Ok(conn) - }))), + inner: GetConnInner::Checking( + async move { + conn.stream_mut()?.check().await?; + Ok(conn) + } + .boxed(), + ), })); } else { self.send_to_recycler(conn); @@ -260,7 +272,7 @@ impl Pool { return Poll::Ready(Ok(GetConn { pool: Some(self.clone()), - inner: GetConnInner::Connecting(BoxFuture(Box::pin(Conn::new(self.opts.clone())))), + inner: GetConnInner::Connecting(Conn::new(self.opts.clone()).boxed()), })); } @@ -358,11 +370,12 @@ mod test { // create some conns.. let connections = (0..NUM_CONNS).map(|_| { - crate::BoxFuture(Box::pin(async { + async { let mut conn = pool.get_conn().await?; conn.ping().await?; - Ok(conn) - })) + crate::Result::Ok(conn) + } + .boxed() }); // collect ids.. @@ -573,7 +586,7 @@ mod test { #[tokio::test] async fn should_hold_bounds_on_error() -> super::Result<()> { // Should not be possible to connect to broadcast address. - let pool = Pool::new(String::from("mysql://255.255.255.255")); + let pool = Pool::new("mysql://255.255.255.255"); assert!(try_join!(pool.get_conn(), pool.get_conn()).is_err()); assert_eq!(ex_field!(pool, exist), 0); @@ -651,7 +664,7 @@ mod test { assert_eq!( *result.unwrap_err().downcast::<&str>().unwrap(), - PANIC_MESSAGE, + "ORIGINAL_PANIC", ); } diff --git a/src/conn/pool/recycler.rs b/src/conn/pool/recycler.rs index da3104d6..b8fae75e 100644 --- a/src/conn/pool/recycler.rs +++ b/src/conn/pool/recycler.rs @@ -7,7 +7,7 @@ // modified, or distributed except according to those terms. use futures_core::stream::Stream; -use futures_util::stream::futures_unordered::FuturesUnordered; +use futures_util::{stream::futures_unordered::FuturesUnordered, FutureExt}; use tokio::sync::mpsc; use std::{ @@ -64,22 +64,18 @@ impl Future for Recycler { ($self:ident, $conn:ident) => { if $conn.inner.stream.is_none() || $conn.inner.disconnected { // drop unestablished connection - $self - .discard - .push(BoxFuture(Box::pin(::futures_util::future::ok(())))); + $self.discard.push(futures_util::future::ok(()).boxed()); } else if $conn.inner.tx_status != TxStatus::None || $conn.inner.pending_result.is_some() { - $self - .cleaning - .push(BoxFuture(Box::pin($conn.cleanup_for_pool()))); + $self.cleaning.push($conn.cleanup_for_pool().boxed()); } else if $conn.expired() || close { - $self.discard.push(BoxFuture(Box::pin($conn.close_conn()))); + $self.discard.push($conn.close_conn().boxed()); } else { let mut exchange = $self.inner.exchange.lock().unwrap(); if exchange.available.len() >= $self.pool_opts.active_bound() { drop(exchange); - $self.discard.push(BoxFuture(Box::pin($conn.close_conn()))); + $self.discard.push($conn.close_conn().boxed()); } else { exchange.available.push_back($conn.into()); if let Some(w) = exchange.waiting.pop_front() { diff --git a/src/conn/routines/exec.rs b/src/conn/routines/exec.rs index aa89b158..4061065e 100644 --- a/src/conn/routines/exec.rs +++ b/src/conn/routines/exec.rs @@ -41,7 +41,7 @@ impl Routine<()> for ExecRoutine<'_> { conn.send_long_data(self.stmt.id(), params.iter()).await?; } - conn.write_command_raw(body).await?; + conn.write_command(&body).await?; conn.read_result_set::(true).await?; break; } @@ -69,7 +69,7 @@ impl Routine<()> for ExecRoutine<'_> { let (body, _) = ComStmtExecuteRequestBuilder::new(self.stmt.id()).build(&[]); - conn.write_command_raw(body).await?; + conn.write_command(&body).await?; conn.read_result_set::(true).await?; break; } diff --git a/src/conn/routines/helpers.rs b/src/conn/routines/helpers.rs index 2ee9b2e6..16a39a01 100644 --- a/src/conn/routines/helpers.rs +++ b/src/conn/routines/helpers.rs @@ -4,8 +4,8 @@ use std::sync::Arc; use mysql_common::{ constants::MAX_PAYLOAD_LEN, - io::ReadMysqlExt, - packets::{parse_local_infile_packet, ComStmtSendLongData}, + io::{ParseBuf, ReadMysqlExt}, + packets::{ComStmtSendLongData, LocalInfilePacket}, value::Value, }; use tokio::io::AsyncReadExt; @@ -34,8 +34,8 @@ impl Conn { None }); for chunk in chunks { - let com = ComStmtSendLongData::new(statement_id, i, chunk); - self.write_command_raw(com.into()).await?; + let com = ComStmtSendLongData::new(statement_id, i as u16, chunk); + self.write_command(&com).await?; } } } @@ -86,7 +86,7 @@ impl Conn { where P: Protocol, { - let local_infile = parse_local_infile_packet(&*packet)?; + let local_infile = ParseBuf(packet).parse::(())?; let (local_infile, handler) = match self.opts().local_infile_handler() { Some(handler) => ((local_infile.into_owned(), handler)), None => return Err(DriverError::NoLocalInfileHandler.into()), @@ -96,7 +96,7 @@ impl Conn { let mut buf = [0; 4096]; loop { let read = reader.read(&mut buf[..]).await?; - self.write_packet(&buf[..read]).await?; + self.write_bytes(&buf[..read]).await?; if read == 0 { break; diff --git a/src/conn/routines/ping.rs b/src/conn/routines/ping.rs index e5abe0b8..f0c04ada 100644 --- a/src/conn/routines/ping.rs +++ b/src/conn/routines/ping.rs @@ -13,8 +13,7 @@ pub struct PingRoutine; impl Routine<()> for PingRoutine { fn call<'a>(&'a mut self, conn: &'a mut Conn) -> BoxFuture<'a, crate::Result<()>> { async move { - conn.write_command_raw(vec![Command::COM_PING as u8]) - .await?; + conn.write_command_data(Command::COM_PING, &[]).await?; conn.read_packet().await?; Ok(()) } diff --git a/src/connection_like/mod.rs b/src/connection_like/mod.rs index 064dc505..bb322187 100644 --- a/src/connection_like/mod.rs +++ b/src/connection_like/mod.rs @@ -6,6 +6,8 @@ // option. All files in the project carrying such notice may not be copied, // modified, or distributed except according to those terms. +use futures_util::FutureExt; + use crate::{BoxFuture, Pool}; /// Connection. @@ -79,10 +81,11 @@ impl<'a, 't: 'a, T: Into> + Send> ToConnection<'a, 't> for T impl<'a> ToConnection<'a, 'static> for &'a Pool { fn to_connection(self) -> ToConnectionResult<'a, 'static> { - let fut = BoxFuture(Box::pin(async move { + let fut = async move { let conn = self.get_conn().await?; Ok(conn.into()) - })); + } + .boxed(); ToConnectionResult::Mediate(fut) } } diff --git a/src/error.rs b/src/error.rs index 288c51a7..db361c17 100644 --- a/src/error.rs +++ b/src/error.rs @@ -9,7 +9,7 @@ pub use url::ParseError; use mysql_common::{ - named_params::MixedParamsError, packets::ErrPacket, params::MissingNamedParameterError, + named_params::MixedParamsError, params::MissingNamedParameterError, proto::codec::error::PacketCodecError, row::Row, value::Value, }; use thiserror::Error; @@ -152,6 +152,9 @@ pub enum DriverError { #[error("Named pipe connections temporary disabled (see tokio-rs/tokio#3118)")] NamedPipesDisabled, + + #[error("`mysql_old_password` plugin is insecure and disabled by default")] + MysqlOldPasswordDisabled, } impl From for Error { @@ -196,8 +199,8 @@ impl From for IoError { } } -impl From> for ServerError { - fn from(packet: ErrPacket<'_>) -> Self { +impl From> for ServerError { + fn from(packet: mysql_common::packets::ServerError<'_>) -> Self { ServerError { code: packet.error_code(), message: packet.message_str().into(), @@ -206,8 +209,8 @@ impl From> for ServerError { } } -impl From> for Error { - fn from(packet: ErrPacket<'_>) -> Self { +impl From> for Error { + fn from(packet: mysql_common::packets::ServerError<'_>) -> Self { Error::Server(packet.into()) } } diff --git a/src/io/mod.rs b/src/io/mod.rs index 4abbe79b..11e22be0 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -34,6 +34,7 @@ use std::{ ErrorKind::{BrokenPipe, NotConnected, Other}, Read, }, + mem::replace, net::{SocketAddr, ToSocketAddrs}, ops::{Deref, DerefMut}, pin::Pin, @@ -41,7 +42,7 @@ use std::{ time::Duration, }; -use crate::{error::IoError, opts::SslOpts}; +use crate::{buffer_pool::PooledBuf, error::IoError, opts::SslOpts}; #[cfg(unix)] use crate::io::socket::Socket; @@ -61,37 +62,54 @@ mod read_packet; mod socket; mod write_packet; -#[derive(Debug, Default)] -pub struct PacketCodec(PacketCodecInner); +#[derive(Debug)] +pub struct PacketCodec { + inner: PacketCodecInner, + decode_buf: PooledBuf, +} + +impl Default for PacketCodec { + fn default() -> Self { + Self { + inner: Default::default(), + decode_buf: crate::BUFFER_POOL.get(), + } + } +} impl Deref for PacketCodec { type Target = PacketCodecInner; fn deref(&self) -> &Self::Target { - &self.0 + &self.inner } } impl DerefMut for PacketCodec { fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 + &mut self.inner } } impl Decoder for PacketCodec { - type Item = Vec; + type Item = PooledBuf; type Error = IoError; fn decode(&mut self, src: &mut BytesMut) -> std::result::Result, IoError> { - Ok(self.0.decode(src)?) + if self.inner.decode(src, self.decode_buf.as_mut())? { + let new_buf = crate::BUFFER_POOL.get(); + Ok(Some(replace(&mut self.decode_buf, new_buf))) + } else { + Ok(None) + } } } -impl Encoder> for PacketCodec { +impl Encoder for PacketCodec { type Error = IoError; - fn encode(&mut self, item: Vec, dst: &mut BytesMut) -> std::result::Result<(), IoError> { - Ok(self.0.encode(item, dst)?) + fn encode(&mut self, item: PooledBuf, dst: &mut BytesMut) -> std::result::Result<(), IoError> { + Ok(self.inner.encode(&mut item.as_ref(), dst)?) } } @@ -515,7 +533,7 @@ impl Stream { } impl stream::Stream for Stream { - type Item = std::result::Result, IoError>; + type Item = std::result::Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { if !self.closed { @@ -552,8 +570,8 @@ mod test { }; assert_eq!( - sock.keepalive().unwrap(), - Some(std::time::Duration::from_millis(42_000)), + sock.keepalive_time().unwrap(), + std::time::Duration::from_millis(42_000), ); std::mem::forget(sock); diff --git a/src/io/read_packet.rs b/src/io/read_packet.rs index 558dbae0..9c69e15e 100644 --- a/src/io/read_packet.rs +++ b/src/io/read_packet.rs @@ -15,7 +15,7 @@ use std::{ task::{Context, Poll}, }; -use crate::{connection_like::Connection, error::IoError}; +use crate::{buffer_pool::PooledBuf, connection_like::Connection, error::IoError, Conn}; /// Reads a packet. #[derive(Debug)] @@ -26,10 +26,14 @@ impl<'a, 't> ReadPacket<'a, 't> { pub(crate) fn new>>(conn: T) -> Self { Self(conn.into()) } + + pub(crate) fn conn_ref(&self) -> &Conn { + &*self.0 + } } impl Future for ReadPacket<'_, '_> { - type Output = std::result::Result, IoError>; + type Output = std::result::Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let packet_opt = match self.0.stream_mut() { diff --git a/src/io/write_packet.rs b/src/io/write_packet.rs index 2f1f48a0..0449edb1 100644 --- a/src/io/write_packet.rs +++ b/src/io/write_packet.rs @@ -16,18 +16,18 @@ use std::{ task::{Context, Poll}, }; -use crate::{connection_like::Connection, error::IoError}; +use crate::{buffer_pool::PooledBuf, connection_like::Connection, error::IoError}; /// Writes a packet. #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct WritePacket<'a, 't> { conn: Connection<'a, 't>, - data: Option>, + data: Option, } impl<'a, 't> WritePacket<'a, 't> { - pub(crate) fn new>>(conn: T, data: Vec) -> Self { + pub(crate) fn new>>(conn: T, data: PooledBuf) -> Self { Self { conn: conn.into(), data: Some(data), diff --git a/src/lib.rs b/src/lib.rs index ab525afe..bdb7b428 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -100,7 +100,9 @@ extern crate test; pub use mysql_common::{chrono, constants as consts, params, time, uuid}; -use std::{future::Future, pin::Pin}; +use std::sync::Arc; + +mod buffer_pool; #[macro_use] mod macros; @@ -114,33 +116,13 @@ mod opts; mod query; mod queryable; -#[must_use = "futures do nothing unless you `.await` or poll them"] -pub struct BoxFuture<'a, T>(Pin> + Send + 'a>>); - -impl Future for BoxFuture<'_, T> { - type Output = Result; +type BoxFuture<'a, T> = futures_core::future::BoxFuture<'a, Result>; - fn poll( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll { - self.0.as_mut().poll(cx) - } -} - -impl<'a, T> std::fmt::Debug for BoxFuture<'a, T> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_tuple("BoxFuture") - .field(&format!( - "dyn Future", - std::any::type_name::() - )) - .finish() - } -} +static BUFFER_POOL: once_cell::sync::Lazy> = + once_cell::sync::Lazy::new(|| Default::default()); #[doc(inline)] -pub use self::conn::Conn; +pub use self::conn::{binlog_stream::BinlogStream, Conn}; #[doc(inline)] pub use self::conn::pool::Pool; @@ -164,7 +146,22 @@ pub use self::opts::{ pub use self::local_infile_handler::{builtin::WhiteListFsLocalInfileHandler, InfileHandlerFuture}; #[doc(inline)] -pub use mysql_common::packets::Column; +pub use mysql_common::packets::{ + binlog_request::BinlogRequest, + session_state_change::{ + Gtids, Schema, SessionStateChange, SystemVariable, TransactionCharacteristics, + TransactionState, Unsupported, + }, + BinlogDumpFlags, Column, Interval, OkPacket, SessionStateInfo, Sid, +}; + +pub mod binlog { + #[doc(inline)] + pub use mysql_common::binlog::consts::*; + + #[doc(inline)] + pub use mysql_common::binlog::{events, jsonb, jsondiff, row, value}; +} #[doc(inline)] pub use mysql_common::proto::codec::Compression; @@ -287,7 +284,7 @@ pub mod test_misc { } pub fn get_opts() -> OptsBuilder { - let mut builder = OptsBuilder::from_opts(&**DATABASE_URL); + let mut builder = OptsBuilder::from_opts(Opts::from_url(&**DATABASE_URL).unwrap()); if test_ssl() { let ssl_opts = SslOpts::default() .with_danger_skip_domain_validation(true) diff --git a/src/opts.rs b/src/opts.rs index a4dac7f3..2f2e2c8d 100644 --- a/src/opts.rs +++ b/src/opts.rs @@ -11,6 +11,7 @@ use url::{Host, Url}; use std::{ borrow::Cow, + convert::TryFrom, io, net::{Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs}, path::Path, @@ -387,6 +388,23 @@ pub(crate) struct MysqlOpts { /// /// Note that compression level defined here will affect only outgoing packets. compression: Option, + + /// Client side `max_allowed_packet` value (defaults to `None`). + /// + /// By default `Conn` will query this value from the server. One can avoid this step + /// by explicitly specifying it. + max_allowed_packet: Option, + + /// Client side `wait_timeout` value (defaults to `None`). + /// + /// By default `Conn` will query this value from the server. One can avoid this step + /// by explicitly specifying it. + wait_timeout: Option, + + /// Disables `mysql_old_password` plugin (defaults to `true`). + /// + /// Available via `secure_auth` connection url parameter. + secure_auth: bool, } /// Mysql connection options. @@ -657,6 +675,33 @@ impl Opts { self.inner.mysql_opts.compression } + /// Client side `max_allowed_packet` value (defaults to `None`). + /// + /// By default `Conn` will query this value from the server. One can avoid this step + /// by explicitly specifying it. Server side default is 4MB. + /// + /// Available in connection URL via `max_allowed_packet` parameter. + pub fn max_allowed_packet(&self) -> Option { + self.inner.mysql_opts.max_allowed_packet + } + + /// Client side `wait_timeout` value (defaults to `None`). + /// + /// By default `Conn` will query this value from the server. One can avoid this step + /// by explicitly specifying it. Server side default is 28800. + /// + /// Available in connection URL via `wait_timeout` parameter. + pub fn wait_timeout(&self) -> Option { + self.inner.mysql_opts.wait_timeout + } + + /// Disables `mysql_old_password` plugin (defaults to `true`). + /// + /// Available via `secure_auth` connection url parameter. + pub fn secure_auth(&self) -> bool { + self.inner.mysql_opts.secure_auth + } + pub(crate) fn get_capabilities(&self) -> CapabilityFlags { let mut out = CapabilityFlags::CLIENT_PROTOCOL_41 | CapabilityFlags::CLIENT_SECURE_CONNECTION @@ -700,6 +745,9 @@ impl Default for MysqlOpts { prefer_socket: cfg!(not(target_os = "windows")), socket: None, compression: None, + max_allowed_packet: None, + wait_timeout: None, + secure_auth: true, } } } @@ -797,8 +845,16 @@ impl Default for OptsBuilder { impl OptsBuilder { /// Creates new builder from the given `Opts`. - pub fn from_opts>(opts: T) -> Self { - let opts = opts.into(); + /// + /// # Panic + /// + /// It'll panic if `Opts::try_from(opts)` returns error. + pub fn from_opts(opts: T) -> Self + where + Opts: TryFrom, + >::Error: std::error::Error, + { + let opts = Opts::try_from(opts).unwrap(); OptsBuilder { tcp_port: opts.inner.address.get_tcp_port(), @@ -908,6 +964,40 @@ impl OptsBuilder { self.opts.compression = compression.into(); self } + + /// Defines `max_allowed_packet` option. See [`Opts::max_allowed_packet`]. + /// + /// Note that it'll saturate to proper minimum and maximum values + /// for this parameter (see MySql documentation). + pub fn max_allowed_packet(mut self, max_allowed_packet: Option) -> Self { + self.opts.max_allowed_packet = + max_allowed_packet.map(|x| std::cmp::max(1024, std::cmp::min(1073741824, x))); + self + } + + /// Defines `wait_timeout` option. See [`Opts::wait_timeout`]. + /// + /// Note that it'll saturate to proper minimum and maximum values + /// for this parameter (see MySql documentation). + pub fn wait_timeout(mut self, wait_timeout: Option) -> Self { + self.opts.wait_timeout = wait_timeout.map(|x| { + #[cfg(windows)] + let val = std::cmp::min(2147483, x); + #[cfg(not(windows))] + let val = std::cmp::min(31536000, x); + + val + }); + self + } + + /// Disables `mysql_old_password` plugin (defaults to `true`). + /// + /// Available via `secure_auth` connection url parameter. + pub fn secure_auth(mut self, secure_auth: bool) -> Self { + self.opts.secure_auth = secure_auth; + self + } } impl From for Opts { @@ -1060,6 +1150,32 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { }); } } + } else if key == "max_allowed_packet" { + match usize::from_str(&*value) { + Ok(value) => { + opts.max_allowed_packet = + Some(std::cmp::max(1024, std::cmp::min(1073741824, value))) + } + _ => { + return Err(UrlError::InvalidParamValue { + param: "max_allowed_packet".into(), + value, + }); + } + } + } else if key == "wait_timeout" { + match usize::from_str(&*value) { + #[cfg(windows)] + Ok(value) => opts.wait_timeout = Some(std::cmp::min(2147483, value)), + #[cfg(not(windows))] + Ok(value) => opts.wait_timeout = Some(std::cmp::min(31536000, value)), + _ => { + return Err(UrlError::InvalidParamValue { + param: "wait_timeout".into(), + value, + }); + } + } } else if key == "tcp_nodelay" { match bool::from_str(&*value) { Ok(value) => opts.tcp_nodelay = value, @@ -1094,6 +1210,18 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { }); } } + } else if key == "secure_auth" { + match bool::from_str(&*value) { + Ok(secure_auth) => { + opts.secure_auth = secure_auth; + } + _ => { + return Err(UrlError::InvalidParamValue { + param: "secure_auth".into(), + value, + }); + } + } } else if key == "socket" { opts.socket = Some(value) } else if key == "compression" { @@ -1138,9 +1266,11 @@ impl FromStr for Opts { } } -impl + Sized> From for Opts { - fn from(url: T) -> Opts { - Opts::from_url(url.as_ref()).unwrap() +impl<'a> TryFrom<&'a str> for Opts { + type Error = UrlError; + + fn try_from(s: &str) -> std::result::Result { + Opts::from_url(s) } } @@ -1220,21 +1350,21 @@ mod test { #[should_panic] fn should_panic_on_invalid_url() { let opts = "42"; - let _: Opts = opts.into(); + let _: Opts = Opts::from_str(opts).unwrap(); } #[test] #[should_panic] fn should_panic_on_invalid_scheme() { let opts = "postgres://localhost"; - let _: Opts = opts.into(); + let _: Opts = Opts::from_str(opts).unwrap(); } #[test] #[should_panic] fn should_panic_on_unknown_query_param() { let opts = "mysql://localhost/foo?bar=baz"; - let _: Opts = opts.into(); + let _: Opts = Opts::from_str(opts).unwrap(); } #[test] diff --git a/src/query.rs b/src/query.rs index 1d0c659b..f2b97ca8 100644 --- a/src/query.rs +++ b/src/query.rs @@ -6,6 +6,8 @@ // option. All files in the project carrying such notice may not be copied, // modified, or distributed except according to those terms. +use futures_util::FutureExt; + use crate::{ connection_like::ToConnectionResult, from_row, @@ -58,7 +60,7 @@ pub trait Query: Send + Sized { C: ToConnection<'a, 't> + 'a, T: FromRow + Send + 'static, { - BoxFuture(Box::pin(async move { + async move { let mut result = self.run(conn).await?; let output = if result.is_empty() { None @@ -67,7 +69,8 @@ pub trait Query: Send + Sized { }; result.drop_result().await?; Ok(output) - })) + } + .boxed() } /// This methods corresponds to [`Queryable::query`][query]. @@ -79,9 +82,7 @@ pub trait Query: Send + Sized { C: ToConnection<'a, 't> + 'a, T: FromRow + Send + 'static, { - BoxFuture(Box::pin(async move { - self.run(conn).await?.collect_and_drop::().await - })) + async move { self.run(conn).await?.collect_and_drop::().await }.boxed() } /// This methods corresponds to [`Queryable::query_fold`][query_fold]. @@ -95,9 +96,7 @@ pub trait Query: Send + Sized { T: FromRow + Send + 'static, U: Send + 'a, { - BoxFuture(Box::pin(async move { - self.run(conn).await?.reduce_and_drop(init, next).await - })) + async move { self.run(conn).await?.reduce_and_drop(init, next).await }.boxed() } /// This methods corresponds to [`Queryable::query_map`][query_map]. @@ -111,12 +110,13 @@ pub trait Query: Send + Sized { T: FromRow + Send + 'static, U: Send + 'a, { - BoxFuture(Box::pin(async move { + async move { self.run(conn) .await? .map_and_drop(|row| map(from_row(row))) .await - })) + } + .boxed() } /// This method corresponds to [`Queryable::query_drop`][query_drop]. @@ -127,9 +127,7 @@ pub trait Query: Send + Sized { Self: 'a, C: ToConnection<'a, 't> + 'a, { - BoxFuture(Box::pin(async move { - self.run(conn).await?.drop_result().await - })) + async move { self.run(conn).await?.drop_result().await }.boxed() } } @@ -141,14 +139,15 @@ impl + Send + Sync> Query for Q { Self: 'a, C: ToConnection<'a, 't> + 'a, { - BoxFuture(Box::pin(async move { + async move { let mut conn = match conn.to_connection() { ToConnectionResult::Immediate(conn) => conn, ToConnectionResult::Mediate(fut) => fut.await?, }; conn.raw_query(self).await?; Ok(QueryResult::new(conn)) - })) + } + .boxed() } } @@ -187,7 +186,7 @@ where Self: 'a, C: ToConnection<'a, 't> + 'a, { - BoxFuture(Box::pin(async move { + async move { let mut conn = match conn.to_connection() { ToConnectionResult::Immediate(conn) => conn, ToConnectionResult::Mediate(fut) => fut.await?, @@ -199,7 +198,8 @@ where .await?; Ok(QueryResult::new(conn)) - })) + } + .boxed() } } @@ -246,7 +246,7 @@ where Self: 'a, C: ToConnection<'a, 't> + 'a, { - BoxFuture(Box::pin(async move { + async move { let mut conn = match conn.to_connection() { ToConnectionResult::Immediate(conn) => conn, ToConnectionResult::Mediate(fut) => fut.await?, @@ -259,7 +259,8 @@ where } Ok(()) - })) + } + .boxed() } } diff --git a/src/queryable/mod.rs b/src/queryable/mod.rs index 14796718..c6b4cf5f 100644 --- a/src/queryable/mod.rs +++ b/src/queryable/mod.rs @@ -6,10 +6,13 @@ // option. All files in the project carrying such notice may not be copied, // modified, or distributed except according to those terms. +use futures_util::FutureExt; use mysql_common::{ - packets::{parse_ok_packet, OkPacketKind}, - row::new_row, - value::{read_bin_values, read_text_values, ServerSide}, + io::ParseBuf, + packets::{OkPacketDeserializer, ResultSetTerminator}, + proto::{Binary, Text}, + row::RowDeserializer, + value::ServerSide, }; use std::{fmt, sync::Arc}; @@ -38,7 +41,10 @@ pub trait Protocol: fmt::Debug + Send + Sync + 'static { fn result_set_meta(columns: Arc<[Column]>) -> ResultSetMeta; fn read_result_set_row(packet: &[u8], columns: Arc<[Column]>) -> Result; fn is_last_result_set_packet(capabilities: CapabilityFlags, packet: &[u8]) -> bool { - parse_ok_packet(packet, capabilities, OkPacketKind::ResultSetTerminator).is_ok() + packet.len() < 8 + && ParseBuf(packet) + .parse::>(capabilities) + .is_ok() } } @@ -56,8 +62,9 @@ impl Protocol for TextProtocol { } fn read_result_set_row(packet: &[u8], columns: Arc<[Column]>) -> Result { - read_text_values(packet, columns.len()) - .map(|values| new_row(values, columns)) + ParseBuf(packet) + .parse::>(columns) + .map(Into::into) .map_err(Into::into) } } @@ -68,8 +75,9 @@ impl Protocol for BinaryProtocol { } fn read_result_set_row(packet: &[u8], columns: Arc<[Column]>) -> Result { - read_bin_values::(packet, &*columns) - .map(|values| new_row(values, columns)) + ParseBuf(packet) + .parse::>(columns) + .map(Into::into) .map_err(Into::into) } } @@ -255,10 +263,11 @@ pub trait Queryable: Send { impl Queryable for Conn { fn ping(&mut self) -> BoxFuture<'_, ()> { - BoxFuture(Box::pin(async move { + async move { self.routine(PingRoutine).await?; Ok(()) - })) + } + .boxed() } fn query_iter<'a, Q>( @@ -268,27 +277,27 @@ impl Queryable for Conn { where Q: AsRef + Send + Sync + 'a, { - BoxFuture(Box::pin(async move { + async move { self.routine(QueryRoutine::new(query.as_ref().as_bytes())) .await?; Ok(QueryResult::new(self)) - })) + } + .boxed() } fn prep<'a, Q>(&'a mut self, query: Q) -> BoxFuture<'a, Statement> where Q: AsRef + Sync + Send + 'a, { - BoxFuture(Box::pin( - async move { self.get_statement(query.as_ref()).await }, - )) + async move { self.get_statement(query.as_ref()).await }.boxed() } fn close(&mut self, stmt: Statement) -> BoxFuture<'_, ()> { - BoxFuture(Box::pin(async move { + async move { self.stmt_cache_mut().remove(stmt.id()); self.close_statement(stmt.id()).await - })) + } + .boxed() } fn exec_iter<'a: 's, 's, Q, P>( @@ -301,11 +310,12 @@ impl Queryable for Conn { P: Into, { let params = params.into(); - BoxFuture(Box::pin(async move { + async move { let statement = self.get_statement(stmt).await?; self.execute_statement(&statement, params).await?; Ok(QueryResult::new(self)) - })) + } + .boxed() } fn query<'a, T, Q>(&'a mut self, query: Q) -> BoxFuture<'a, Vec> @@ -313,9 +323,7 @@ impl Queryable for Conn { Q: AsRef + Send + Sync + 'a, T: FromRow + Send + 'static, { - BoxFuture(Box::pin(async move { - self.query_iter(query).await?.collect_and_drop::().await - })) + async move { self.query_iter(query).await?.collect_and_drop::().await }.boxed() } fn query_first<'a, T, Q>(&'a mut self, query: Q) -> BoxFuture<'a, Option> @@ -323,7 +331,7 @@ impl Queryable for Conn { Q: AsRef + Send + Sync + 'a, T: FromRow + Send + 'static, { - BoxFuture(Box::pin(async move { + async move { let mut result = self.query_iter(query).await?; let output = if result.is_empty() { None @@ -332,7 +340,8 @@ impl Queryable for Conn { }; result.drop_result().await?; Ok(output) - })) + } + .boxed() } fn query_map<'a, T, F, Q, U>(&'a mut self, query: Q, mut f: F) -> BoxFuture<'a, Vec> @@ -342,13 +351,14 @@ impl Queryable for Conn { F: FnMut(T) -> U + Send + 'a, U: Send, { - BoxFuture(Box::pin(async move { + async move { self.query_fold(query, Vec::new(), |mut acc, row| { acc.push(f(crate::from_row(row))); acc }) .await - })) + } + .boxed() } fn query_fold<'a, T, F, Q, U>(&'a mut self, query: Q, init: U, mut f: F) -> BoxFuture<'a, U> @@ -358,21 +368,20 @@ impl Queryable for Conn { F: FnMut(U, T) -> U + Send + 'a, U: Send + 'a, { - BoxFuture(Box::pin(async move { + async move { self.query_iter(query) .await? .reduce_and_drop(init, |acc, row| f(acc, crate::from_row(row))) .await - })) + } + .boxed() } fn query_drop<'a, Q>(&'a mut self, query: Q) -> BoxFuture<'a, ()> where Q: AsRef + Send + Sync + 'a, { - BoxFuture(Box::pin(async move { - self.query_iter(query).await?.drop_result().await - })) + async move { self.query_iter(query).await?.drop_result().await }.boxed() } fn exec_batch<'a: 'b, 'b, S, P, I>(&'a mut self, stmt: S, params_iter: I) -> BoxFuture<'b, ()> @@ -382,7 +391,7 @@ impl Queryable for Conn { I::IntoIter: Send, P: Into + Send, { - BoxFuture(Box::pin(async move { + async move { let statement = self.get_statement(stmt).await?; for params in params_iter { self.execute_statement(&statement, params).await?; @@ -391,7 +400,8 @@ impl Queryable for Conn { .await?; } Ok(()) - })) + } + .boxed() } fn exec<'a: 'b, 'b, T, S, P>(&'a mut self, stmt: S, params: P) -> BoxFuture<'b, Vec> @@ -400,12 +410,13 @@ impl Queryable for Conn { P: Into + Send + 'b, T: FromRow + Send + 'static, { - BoxFuture(Box::pin(async move { + async move { self.exec_iter(stmt, params) .await? .collect_and_drop::() .await - })) + } + .boxed() } fn exec_first<'a: 'b, 'b, T, S, P>(&'a mut self, stmt: S, params: P) -> BoxFuture<'b, Option> @@ -414,7 +425,7 @@ impl Queryable for Conn { P: Into + Send + 'b, T: FromRow + Send + 'static, { - BoxFuture(Box::pin(async move { + async move { let mut result = self.exec_iter(stmt, params).await?; let row = if result.is_empty() { None @@ -423,7 +434,8 @@ impl Queryable for Conn { }; result.drop_result().await?; Ok(row.map(crate::from_row)) - })) + } + .boxed() } fn exec_map<'a: 'b, 'b, T, S, P, U, F>( @@ -439,13 +451,14 @@ impl Queryable for Conn { F: FnMut(T) -> U + Send + 'a, U: Send + 'a, { - BoxFuture(Box::pin(async move { + async move { self.exec_fold(stmt, params, Vec::new(), |mut acc, row| { acc.push(f(crate::from_row(row))); acc }) .await - })) + } + .boxed() } fn exec_fold<'a: 'b, 'b, T, S, P, U, F>( @@ -462,12 +475,13 @@ impl Queryable for Conn { F: FnMut(U, T) -> U + Send + 'a, U: Send + 'a, { - BoxFuture(Box::pin(async move { + async move { self.exec_iter(stmt, params) .await? .reduce_and_drop(init, |acc, row| f(acc, crate::from_row(row))) .await - })) + } + .boxed() } fn exec_drop<'a: 'b, 'b, S, P>(&'a mut self, stmt: S, params: P) -> BoxFuture<'b, ()> @@ -475,9 +489,7 @@ impl Queryable for Conn { S: StatementLike + 'b, P: Into + Send + 'b, { - BoxFuture(Box::pin(async move { - self.exec_iter(stmt, params).await?.drop_result().await - })) + async move { self.exec_iter(stmt, params).await?.drop_result().await }.boxed() } } diff --git a/src/queryable/stmt.rs b/src/queryable/stmt.rs index 066587da..e601264a 100644 --- a/src/queryable/stmt.rs +++ b/src/queryable/stmt.rs @@ -6,9 +6,11 @@ // option. All files in the project carrying such notice may not be copied, // modified, or distributed except according to those terms. +use futures_util::FutureExt; use mysql_common::{ + io::ParseBuf, named_params::parse_named_params, - packets::{column_from_payload, parse_stmt_packet, ComStmtClose, StmtPacket}, + packets::{ComStmtClose, StmtPacket}, }; use std::{borrow::Cow, sync::Arc}; @@ -39,14 +41,15 @@ fn to_statement_move<'a, T: AsRef + Send + Sync + 'a>( stmt: T, conn: &'a mut crate::Conn, ) -> ToStatementResult<'a> { - let fut = crate::BoxFuture(Box::pin(async move { + let fut = async move { let (named_params, raw_query) = parse_named_params(stmt.as_ref())?; let inner_stmt = match conn.get_cached_stmt(&*raw_query) { Some(inner_stmt) => inner_stmt, None => conn.prepare_statement(raw_query).await?, }; Ok(Statement::new(inner_stmt, named_params)) - })); + } + .boxed(); ToStatementResult::Mediate(fut) } @@ -129,7 +132,7 @@ impl StmtInner { connection_id: u32, raw_query: Arc, ) -> std::io::Result { - let stmt_packet = parse_stmt_packet(pld)?; + let stmt_packet = ParseBuf(pld).parse(())?; Ok(Self { raw_query, @@ -244,7 +247,7 @@ impl crate::Conn { let packets = self.read_packets(num).await?; let defs = packets .into_iter() - .map(column_from_payload) + .map(|x| ParseBuf(&*x).parse(())) .collect::, _>>() .map_err(Error::from)?; @@ -299,6 +302,6 @@ impl crate::Conn { /// Helper, that closes statement with the given id. pub(crate) async fn close_statement(&mut self, id: u32) -> Result<()> { self.stmt_cache_mut().remove(id); - self.write_command_raw(ComStmtClose::new(id).into()).await + self.write_command(&ComStmtClose::new(id)).await } } diff --git a/tests/exports.rs b/tests/exports.rs index e6ff0466..fa75224f 100644 --- a/tests/exports.rs +++ b/tests/exports.rs @@ -7,9 +7,9 @@ use mysql_async::{ BatchQuery, ConvIr, FromRow, FromValue, LocalInfileHandler, Protocol, Query, Queryable, StatementLike, ToValue, }, - time, uuid, BinaryProtocol, BoxFuture, Column, Conn, Deserialized, DriverError, Error, - FromRowError, FromValueError, IoError, IsolationLevel, Opts, OptsBuilder, Params, ParseError, - Pool, PoolConstraints, PoolOpts, QueryResult, Result, Row, Serialized, ServerError, SslOpts, + time, uuid, BinaryProtocol, Column, Conn, Deserialized, DriverError, Error, FromRowError, + FromValueError, IoError, IsolationLevel, Opts, OptsBuilder, Params, ParseError, Pool, + PoolConstraints, PoolOpts, QueryResult, Result, Row, Serialized, ServerError, SslOpts, Statement, TextProtocol, Transaction, TxOpts, UrlError, Value, WhiteListFsLocalInfileHandler, DEFAULT_INACTIVE_CONNECTION_TTL, DEFAULT_TTL_CHECK_INTERVAL, };