From d2099e93e10166fcd3260f72f81b9597e87756f3 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Sat, 23 Apr 2022 20:17:34 +0300 Subject: [PATCH] Simplify and document LOCAL INFILE handlers. --- Cargo.toml | 2 +- README.md | 187 ++++++++++++++++++++++++++-- README.tpl | 30 +++++ src/conn/mod.rs | 75 +++++++++-- src/conn/pool/mod.rs | 2 + src/conn/routines/helpers.rs | 40 +++--- src/error.rs | 51 +++++--- src/lib.rs | 120 ++++++++++++++++-- src/local_infile_handler/builtin.rs | 50 ++++---- src/local_infile_handler/mod.rs | 133 ++++++++------------ src/opts.rs | 12 +- tests/exports.rs | 6 +- 12 files changed, 537 insertions(+), 171 deletions(-) create mode 100644 README.tpl diff --git a/Cargo.toml b/Cargo.toml index 54d6390b..943044f5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,7 +33,7 @@ serde_json = "1" socket2 = "0.4.2" thiserror = "1.0.4" tokio = { version = "1.0", features = ["io-util", "fs", "net", "time", "rt"] } -tokio-util = { version = "0.6.0", features = ["codec"] } +tokio-util = { version = "0.6.0", features = ["codec", "io"] } tokio-native-tls = "0.3.0" twox-hash = "1" url = "2.1" diff --git a/README.md b/README.md index 3042f538..40190734 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,3 @@ -## mysql-async - -Tokio based asynchronous MySql client library for The Rust Programming Language. - [![Gitter](https://badges.gitter.im/rust-mysql/community.svg)](https://gitter.im/rust-mysql/community?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge) [![Build Status](https://dev.azure.com/aikorsky/mysql%20Rust/_apis/build/status/blackbeam.mysql_async?branchName=master)](https://dev.azure.com/aikorsky/mysql%20Rust/_build/latest?definitionId=2&branchName=master) @@ -9,27 +5,192 @@ Tokio based asynchronous MySql client library for The Rust Programming Language. [![](https://img.shields.io/crates/d/mysql_async.svg)](https://crates.io/crates/mysql_async) [![API Documentation on docs.rs](https://docs.rs/mysql_async/badge.svg)](https://docs.rs/mysql_async) -### Installation +# mysql_async + +Tokio based asynchronous MySql client library for The Rust Programming Language. + +## Installation The library is hosted on [crates.io](https://crates.io/crates/mysql_async/). + ```toml [dependencies] mysql_async = "" ``` -### Example +## Example + +```rust +use mysql_async::prelude::*; + +#[derive(Debug, PartialEq, Eq, Clone)] +struct Payment { + customer_id: i32, + amount: i32, + account_name: Option, +} + +#[tokio::main] +async fn main() -> Result<()> { + let payments = vec![ + Payment { customer_id: 1, amount: 2, account_name: None }, + Payment { customer_id: 3, amount: 4, account_name: Some("foo".into()) }, + Payment { customer_id: 5, amount: 6, account_name: None }, + Payment { customer_id: 7, amount: 8, account_name: None }, + Payment { customer_id: 9, amount: 10, account_name: Some("bar".into()) }, + ]; + + let database_url = /* ... */ + # get_opts(); + + let pool = mysql_async::Pool::new(database_url); + let mut conn = pool.get_conn().await?; + + // Create a temporary table + r"CREATE TEMPORARY TABLE payment ( + customer_id int not null, + amount int not null, + account_name text + )".ignore(&mut conn).await?; + + // Save payments + r"INSERT INTO payment (customer_id, amount, account_name) + VALUES (:customer_id, :amount, :account_name)" + .with(payments.iter().map(|payment| params! { + "customer_id" => payment.customer_id, + "amount" => payment.amount, + "account_name" => payment.account_name.as_ref(), + })) + .batch(&mut conn) + .await?; + + // Load payments from the database. Type inference will work here. + let loaded_payments = "SELECT customer_id, amount, account_name FROM payment" + .with(()) + .map(&mut conn, |(customer_id, amount, account_name)| Payment { customer_id, amount, account_name }) + .await?; + + // Dropped connection will go to the pool + drop(conn); + + // The Pool must be disconnected explicitly because + // it's an asynchronous operation. + pool.disconnect().await?; + + assert_eq!(loaded_payments, payments); + + // the async fn returns Result, so + Ok(()) +} +``` + +## LOCAL INFILE Handlers + +**Warning:** You should be aware of [Security Considerations for LOAD DATA LOCAL][1]. + +There are two flavors of LOCAL INFILE handlers – _global_ and _local_. + +I case of a LOCAL INFILE request from the server the driver will try to find a handler for it: + +1. It'll try to use _local_ handler installed on the connection, if any; +2. It'll try to use _global_ handler, specified via [`OptsBuilder::local_infile_handler`], + if any; +3. It will emit [`LocalInfileError::NoHandler`] if no handlers found. + +The purpose of a handler (_local_ or _global_) is to return [`InfileData`]. + +### _Global_ LOCAL INFILE handler + +See [`prelude::GlobalHandler`]. + +Simply speaking the _global_ handler is an async function that takes a file name (as `&[u8]`) +and returns `Result`. + +You can set it up using [`OptsBuilder::local_infile_handler`]. Server will use it if there is no +_local_ handler installed for the connection. This handler might be called multiple times. + +Examles: + +1. [`WhiteListFsHandler`] is a _global_ handler. +2. Every `T: Fn(&[u8]) -> BoxFuture<'static, Result>` + is a _global_ handler. + +### _Local_ LOCAL INFILE handler. + +Simply speaking the _local_ handler is a future, that returns `Result`. + +This is a one-time handler – it's consumed after use. You can set it up using +[`Conn::set_infile_handler`]. This handler have priority over _global_ handler. + +Worth noting: + +1. `impl Drop for Conn` will clear _local_ handler, i.e. handler will be removed when + connection is returned to a `Pool`. +2. [`Conn::reset`] will clear _local_ handler. + +Example: + +```rust +# +let pool = mysql_async::Pool::new(database_url); + +let mut conn = pool.get_conn().await?; +"CREATE TEMPORARY TABLE tmp (id INT, val TEXT)".ignore(&mut conn).await?; + +// We are going to call `LOAD DATA LOCAL` so let's setup a one-time handler. +conn.set_infile_handler(async move { + // We need to return a stream of `io::Result` + Ok(stream::iter([Bytes::from("1,a\r\n"), Bytes::from("2,b\r\n3,c")]).map(Ok).boxed()) +}); + +let result = r#"LOAD DATA LOCAL INFILE 'whatever' + INTO TABLE `tmp` + FIELDS TERMINATED BY ',' ENCLOSED BY '\"' + LINES TERMINATED BY '\r\n'"#.ignore(&mut conn).await; + +match result { + Ok(()) => (), + Err(Error::Server(ref err)) if err.code == 1148 => { + // The used command is not allowed with this MySQL version + return Ok(()); + }, + Err(Error::Server(ref err)) if err.code == 3948 => { + // Loading local data is disabled; + // this must be enabled on both the client and the server + return Ok(()); + } + e @ Err(_) => e.unwrap(), +} + +// Now let's verify the result +let result: Vec<(u32, String)> = conn.query("SELECT * FROM tmp ORDER BY id ASC").await?; +assert_eq!( + result, + vec![(1, "a".into()), (2, "b".into()), (3, "c".into())] +); + +drop(conn); +pool.disconnect().await?; +``` + +[1]: https://dev.mysql.com/doc/refman/8.0/en/load-data-local-security.html -Please see the crate docs – [docs.rs](https://docs.rs/mysql_async). +## Change log -### License +Available [here](https://github.com/blackbeam/mysql_async/releases) + +## License Licensed under either of - * Apache License, Version 2.0 ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0) - * MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT) + +* Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or https://www.apache.org/licenses/LICENSE-2.0) +* MIT license ([LICENSE-MIT](LICENSE-MIT) or https://opensource.org/licenses/MIT) + at your option. ### Contribution -Unless you explicitly state otherwise, any contribution intentionally submitted -for inclusion in the work by you, as defined in the Apache-2.0 license, shall be dual licensed as above, without any -additional terms or conditions. +Unless you explicitly state otherwise, any contribution intentionally +submitted for inclusion in the work by you, as defined in the Apache-2.0 +license, shall be dual licensed as above, without any additional terms or +conditions. diff --git a/README.tpl b/README.tpl new file mode 100644 index 00000000..889ba558 --- /dev/null +++ b/README.tpl @@ -0,0 +1,30 @@ +[![Gitter](https://badges.gitter.im/rust-mysql/community.svg)](https://gitter.im/rust-mysql/community?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge) + +[![Build Status](https://dev.azure.com/aikorsky/mysql%20Rust/_apis/build/status/blackbeam.mysql_async?branchName=master)](https://dev.azure.com/aikorsky/mysql%20Rust/_build/latest?definitionId=2&branchName=master) +[![](https://meritbadge.herokuapp.com/mysql_async)](https://crates.io/crates/mysql_async) +[![](https://img.shields.io/crates/d/mysql_async.svg)](https://crates.io/crates/mysql_async) +[![API Documentation on docs.rs](https://docs.rs/mysql_async/badge.svg)](https://docs.rs/mysql_async) + +# {{crate}} + +{{readme}} + +## Change log + +Available [here](https://github.com/blackbeam/mysql_async/releases) + +## License + +Licensed under either of + +* Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or https://www.apache.org/licenses/LICENSE-2.0) +* MIT license ([LICENSE-MIT](LICENSE-MIT) or https://opensource.org/licenses/MIT) + +at your option. + +### Contribution + +Unless you explicitly state otherwise, any contribution intentionally +submitted for inclusion in the work by you, as defined in the Apache-2.0 +license, shall be dual licensed as above, without any additional terms or +conditions. \ No newline at end of file diff --git a/src/conn/mod.rs b/src/conn/mod.rs index 2a2caf1c..436afff6 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -44,7 +44,7 @@ use crate::{ transaction::TxStatus, BinaryProtocol, Queryable, TextProtocol, }, - BinlogStream, OptsBuilder, + BinlogStream, InfileData, OptsBuilder, }; use self::routines::Routine; @@ -111,6 +111,9 @@ struct ConnInner { auth_switched: bool, /// Connection is already disconnected. pub(crate) disconnected: bool, + /// One-time connection-level infile handler. + infile_handler: + Option> + Send + Sync + 'static>>>, } impl fmt::Debug for ConnInner { @@ -151,6 +154,7 @@ impl ConnInner { auth_plugin: AuthPlugin::MysqlNativePassword, auth_switched: false, disconnected: false, + infile_handler: None, } } @@ -374,6 +378,20 @@ impl Conn { &self.inner.opts } + /// Setup _local_ `LOCAL INFILE` handler (see ["LOCAL INFILE Handlers"][2] section + /// of the crate-level docs). + /// + /// It'll overwrite existing _local_ handler, if any. + /// + /// [2]: ../mysql_async/#local-infile-handlers + pub fn set_infile_handler(&mut self, handler: T) + where + T: Future>, + T: Send + Sync + 'static, + { + self.inner.infile_handler = Some(Box::pin(handler)); + } + fn take_stream(&mut self) -> Stream { self.inner.stream.take().unwrap() } @@ -911,6 +929,7 @@ impl Conn { }; self.inner.stmt_cache.clear(); + self.inner.infile_handler = None; self.inner.pool = pool; Ok(()) } @@ -1022,7 +1041,8 @@ impl Conn { #[cfg(test)] mod test { - use futures_util::stream::StreamExt; + use bytes::Bytes; + use futures_util::stream::{self, StreamExt}; use mysql_common::binlog::events::EventData; use tokio::time::timeout; @@ -1030,7 +1050,7 @@ mod test { use crate::{ from_row, params, prelude::*, test_misc::get_opts, BinlogDumpFlags, BinlogRequest, Conn, - Error, OptsBuilder, Pool, WhiteListFsLocalInfileHandler, + Error, OptsBuilder, Pool, WhiteListFsHandler, }; async fn gen_dummy_data() -> super::Result<()> { @@ -1676,8 +1696,7 @@ mod test { write(file_name, b"AAAAAA\nBBBBBB\nCCCCCC\n")?; - let opts = get_opts() - .local_infile_handler(Some(WhiteListFsLocalInfileHandler::new(&[file_name][..]))); + let opts = get_opts().local_infile_handler(Some(WhiteListFsHandler::new(&[file_name][..]))); // LOCAL INFILE in the middle of a multi-result set should not break anything. let mut conn = Conn::new(opts).await.unwrap(); @@ -1802,7 +1821,48 @@ mod test { } #[tokio::test] - async fn should_handle_local_infile() -> super::Result<()> { + async fn should_handle_local_infile_locally() -> super::Result<()> { + let mut conn = Conn::new(get_opts()).await.unwrap(); + conn.query_drop("CREATE TEMPORARY TABLE tmp (a TEXT);") + .await + .unwrap(); + + conn.set_infile_handler(async move { + Ok( + stream::iter([Bytes::from("AAAAAA\n"), Bytes::from("BBBBBB\nCCCCCC\n")]) + .map(Ok) + .boxed(), + ) + }); + + match conn + .query_drop(r#"LOAD DATA LOCAL INFILE "dummy" INTO TABLE tmp;"#) + .await + { + Ok(_) => (), + Err(super::Error::Server(ref err)) if err.code == 1148 => { + // The used command is not allowed with this MySQL version + return Ok(()); + } + Err(super::Error::Server(ref err)) if err.code == 3948 => { + // Loading local data is disabled; + // this must be enabled on both the client and server sides + return Ok(()); + } + e @ Err(_) => e.unwrap(), + }; + + let result: Vec = conn.query("SELECT * FROM tmp").await?; + assert_eq!(result.len(), 3); + assert_eq!(result[0], "AAAAAA"); + assert_eq!(result[1], "BBBBBB"); + assert_eq!(result[2], "CCCCCC"); + + Ok(()) + } + + #[tokio::test] + async fn should_handle_local_infile_globally() -> super::Result<()> { use std::fs::write; let file_path = tempfile::Builder::new().tempfile_in("").unwrap(); @@ -1811,8 +1871,7 @@ mod test { write(file_name, b"AAAAAA\nBBBBBB\nCCCCCC\n")?; - let opts = get_opts() - .local_infile_handler(Some(WhiteListFsLocalInfileHandler::new(&[file_name][..]))); + let opts = get_opts().local_infile_handler(Some(WhiteListFsHandler::new(&[file_name][..]))); let mut conn = Conn::new(opts).await.unwrap(); conn.query_drop("CREATE TEMPORARY TABLE tmp (a TEXT);") diff --git a/src/conn/pool/mod.rs b/src/conn/pool/mod.rs index c4818f6a..ffbc2b89 100644 --- a/src/conn/pool/mod.rs +++ b/src/conn/pool/mod.rs @@ -284,6 +284,8 @@ impl Pool { impl Drop for Conn { fn drop(&mut self) { + self.inner.infile_handler = None; + if std::thread::panicking() { // Try to decrease the number of existing connections. if let Some(pool) = self.inner.pool.take() { diff --git a/src/conn/routines/helpers.rs b/src/conn/routines/helpers.rs index 0bf5dba9..eed0f38e 100644 --- a/src/conn/routines/helpers.rs +++ b/src/conn/routines/helpers.rs @@ -2,15 +2,15 @@ use std::sync::Arc; +use futures_util::StreamExt; use mysql_common::{ constants::MAX_PAYLOAD_LEN, io::{ParseBuf, ReadMysqlExt}, packets::{ComStmtSendLongData, LocalInfilePacket}, value::Value, }; -use tokio::io::AsyncReadExt; -use crate::{queryable::Protocol, Conn, DriverError, Error}; +use crate::{error::LocalInfileError, queryable::Protocol, Conn, Error}; impl Conn { /// Helper, that sends all `Value::Bytes` in the given list of paramenters as long data. @@ -84,27 +84,39 @@ impl Conn { P: Protocol, { let local_infile = ParseBuf(packet).parse::(())?; - let (local_infile, handler) = match self.opts().local_infile_handler() { - Some(handler) => ((local_infile.into_owned(), handler)), - None => return Err(DriverError::NoLocalInfileHandler.into()), - }; - let mut reader = handler.handle(local_infile.file_name_ref()).await?; - let mut buf = [0; 4096]; - loop { - let read = reader.read(&mut buf[..]).await?; - self.write_bytes(&buf[..read]).await?; + let mut infile_data = if let Some(handler) = self.inner.infile_handler.take() { + handler.await? + } else if let Some(handler) = self.opts().local_infile_handler() { + handler.handle(local_infile.file_name_ref()).await? + } else { + return Err(LocalInfileError::NoHandler.into()); + }; - if read == 0 { - break; + let mut result = Ok(()); + while let Some(bytes) = infile_data.next().await { + match bytes { + Ok(bytes) => { + // We'll skip empty chunks to stay compliant with the protocol. + if bytes.len() > 0 { + self.write_bytes(&bytes).await?; + } + } + Err(err) => { + // Abort the stream in case of an error. + result = Err(LocalInfileError::from(err)); + break; + } } } + self.write_bytes(&[]).await?; self.read_packet().await?; self.set_pending_result(Some(P::result_set_meta(Arc::from( Vec::new().into_boxed_slice(), ))))?; - Ok(()) + + result.map_err(Into::into) } /// Helper that handles result set packet. diff --git a/src/error.rs b/src/error.rs index db361c17..8d82a96b 100644 --- a/src/error.rs +++ b/src/error.rs @@ -14,7 +14,7 @@ use mysql_common::{ }; use thiserror::Error; -use std::{borrow::Cow, io, result}; +use std::{io, result}; /// Result type alias for this library. pub type Result = result::Result; @@ -29,7 +29,7 @@ pub enum Error { Io(#[source] IoError), #[error("Other error: {}", _0)] - Other(Cow<'static, str>), + Other(#[source] Box), #[error("Server error: `{}'", _0)] Server(#[source] ServerError), @@ -93,7 +93,7 @@ pub enum UrlError { } /// This type enumerates driver errors. -#[derive(Debug, Error, Clone, PartialEq)] +#[derive(Debug, Error)] pub enum DriverError { #[error("Can't parse server version from string `{}'.", version_string)] CantParseServerVersion { version_string: String }, @@ -119,9 +119,6 @@ pub enum DriverError { #[error("Transactions couldn't be nested.")] NestedTransaction, - #[error("Can't handle local infile request. Handler not specified.")] - NoLocalInfileHandler, - #[error("Packet out of order.")] PacketOutOfOrder, @@ -155,6 +152,36 @@ pub enum DriverError { #[error("`mysql_old_password` plugin is insecure and disabled by default")] MysqlOldPasswordDisabled, + + #[error("LOCAL INFILE error: {}", _0)] + LocalInfile(#[from] LocalInfileError), +} + +#[derive(Debug, Error)] +pub enum LocalInfileError { + #[error("The given path is not in the while list: {}", _0)] + PathIsNotInTheWhiteList(String), + #[error("Error reading `INFILE` data: {}", _0)] + ReadError(#[from] io::Error), + #[error("Can't handle local infile request. Handler is not specified.")] + NoHandler, + #[error(transparent)] + OtherError(Box), +} + +impl LocalInfileError { + pub fn other(err: T) -> Self + where + T: std::error::Error + Send + Sync + 'static, + { + Self::OtherError(Box::new(err)) + } +} + +impl From for Error { + fn from(err: LocalInfileError) -> Self { + Self::Driver(err.into()) + } } impl From for Error { @@ -246,18 +273,6 @@ impl From for Error { } } -impl From for Error { - fn from(err: String) -> Self { - Error::Other(Cow::from(err)) - } -} - -impl From<&'static str> for Error { - fn from(err: &'static str) -> Self { - Error::Other(Cow::from(err)) - } -} - impl From for UrlError { fn from(err: ParseError) -> Self { UrlError::Parse(err) diff --git a/src/lib.rs b/src/lib.rs index f71d3b7e..5fd20f13 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,10 +6,10 @@ // option. All files in the project carrying such notice may not be copied, // modified, or distributed except according to those terms. -//! ## mysql-async //! Tokio based asynchronous MySql client library for The Rust Programming Language. //! -//! ### Installation +//! # Installation +//! //! The library is hosted on [crates.io](https://crates.io/crates/mysql_async/). //! //! ```toml @@ -17,7 +17,7 @@ //! mysql_async = "" //! ``` //! -//! ### Example +//! # Example //! //! ```rust //! # use mysql_async::{Result, test_misc::get_opts}; @@ -84,6 +84,107 @@ //! Ok(()) //! } //! ``` +//! +//! # LOCAL INFILE Handlers +//! +//! **Warning:** You should be aware of [Security Considerations for LOAD DATA LOCAL][1]. +//! +//! There are two flavors of LOCAL INFILE handlers – _global_ and _local_. +//! +//! I case of a LOCAL INFILE request from the server the driver will try to find a handler for it: +//! +//! 1. It'll try to use _local_ handler installed on the connection, if any; +//! 2. It'll try to use _global_ handler, specified via [`OptsBuilder::local_infile_handler`], +//! if any; +//! 3. It will emit [`LocalInfileError::NoHandler`] if no handlers found. +//! +//! The purpose of a handler (_local_ or _global_) is to return [`InfileData`]. +//! +//! ## _Global_ LOCAL INFILE handler +//! +//! See [`prelude::GlobalHandler`]. +//! +//! Simply speaking the _global_ handler is an async function that takes a file name (as `&[u8]`) +//! and returns `Result`. +//! +//! You can set it up using [`OptsBuilder::local_infile_handler`]. Server will use it if there is no +//! _local_ handler installed for the connection. This handler might be called multiple times. +//! +//! Examles: +//! +//! 1. [`WhiteListFsHandler`] is a _global_ handler. +//! 2. Every `T: Fn(&[u8]) -> BoxFuture<'static, Result>` +//! is a _global_ handler. +//! +//! ## _Local_ LOCAL INFILE handler. +//! +//! Simply speaking the _local_ handler is a future, that returns `Result`. +//! +//! This is a one-time handler – it's consumed after use. You can set it up using +//! [`Conn::set_infile_handler`]. This handler have priority over _global_ handler. +//! +//! Worth noting: +//! +//! 1. `impl Drop for Conn` will clear _local_ handler, i.e. handler will be removed when +//! connection is returned to a `Pool`. +//! 2. [`Conn::reset`] will clear _local_ handler. +//! +//! Example: +//! +//! ```rust +//! # use mysql_async::{prelude::*, test_misc::get_opts, OptsBuilder, Result, Error}; +//! # use futures_util::future::FutureExt; +//! # use futures_util::stream::{self, StreamExt}; +//! # use bytes::Bytes; +//! # use std::env; +//! # #[tokio::main] +//! # async fn main() -> Result<()> { +//! # +//! # let database_url = get_opts(); +//! let pool = mysql_async::Pool::new(database_url); +//! +//! let mut conn = pool.get_conn().await?; +//! "CREATE TEMPORARY TABLE tmp (id INT, val TEXT)".ignore(&mut conn).await?; +//! +//! // We are going to call `LOAD DATA LOCAL` so let's setup a one-time handler. +//! conn.set_infile_handler(async move { +//! // We need to return a stream of `io::Result` +//! Ok(stream::iter([Bytes::from("1,a\r\n"), Bytes::from("2,b\r\n3,c")]).map(Ok).boxed()) +//! }); +//! +//! let result = r#"LOAD DATA LOCAL INFILE 'whatever' +//! INTO TABLE `tmp` +//! FIELDS TERMINATED BY ',' ENCLOSED BY '\"' +//! LINES TERMINATED BY '\r\n'"#.ignore(&mut conn).await; +//! +//! match result { +//! Ok(()) => (), +//! Err(Error::Server(ref err)) if err.code == 1148 => { +//! // The used command is not allowed with this MySQL version +//! return Ok(()); +//! }, +//! Err(Error::Server(ref err)) if err.code == 3948 => { +//! // Loading local data is disabled; +//! // this must be enabled on both the client and the server +//! return Ok(()); +//! } +//! e @ Err(_) => e.unwrap(), +//! } +//! +//! // Now let's verify the result +//! let result: Vec<(u32, String)> = conn.query("SELECT * FROM tmp ORDER BY id ASC").await?; +//! assert_eq!( +//! result, +//! vec![(1, "a".into()), (2, "b".into()), (3, "c".into())] +//! ); +//! +//! drop(conn); +//! pool.disconnect().await?; +//! # Ok(()) +//! # } +//! ``` +//! +//! [1]: https://dev.mysql.com/doc/refman/8.0/en/load-data-local-security.html #![recursion_limit = "1024"] #![cfg_attr(feature = "nightly", feature(test))] @@ -121,7 +222,9 @@ pub use self::conn::{binlog_stream::BinlogStream, Conn}; pub use self::conn::pool::Pool; #[doc(inline)] -pub use self::error::{DriverError, Error, IoError, ParseError, Result, ServerError, UrlError}; +pub use self::error::{ + DriverError, Error, IoError, LocalInfileError, ParseError, Result, ServerError, UrlError, +}; #[doc(inline)] pub use self::query::QueryWithParams; @@ -136,7 +239,7 @@ pub use self::opts::{ }; #[doc(inline)] -pub use self::local_infile_handler::{builtin::WhiteListFsLocalInfileHandler, InfileHandlerFuture}; +pub use self::local_infile_handler::{builtin::WhiteListFsHandler, InfileData}; #[doc(inline)] pub use mysql_common::packets::{ @@ -197,7 +300,7 @@ pub mod futures { /// Traits used in this crate pub mod prelude { #[doc(inline)] - pub use crate::local_infile_handler::LocalInfileHandler; + pub use crate::local_infile_handler::GlobalHandler; #[doc(inline)] pub use crate::query::{BatchQuery, Query, WithParams}; #[doc(inline)] @@ -256,9 +359,10 @@ pub mod test_misc { use crate::opts::{Opts, OptsBuilder, SslOpts}; #[allow(dead_code)] + #[allow(unreachable_code)] fn error_should_implement_send_and_sync() { - fn _dummy(_: T) {} - _dummy(crate::error::Error::from("foo")); + fn _dummy(_: T) {} + _dummy(panic!()); } lazy_static! { diff --git a/src/local_infile_handler/builtin.rs b/src/local_infile_handler/builtin.rs index 2638f13e..e8b4ac9b 100644 --- a/src/local_infile_handler/builtin.rs +++ b/src/local_infile_handler/builtin.rs @@ -6,32 +6,40 @@ // option. All files in the project carrying such notice may not be copied, // modified, or distributed except according to those terms. +use futures_util::FutureExt; use tokio::fs::File; +use tokio_util::io::ReaderStream; -use std::{collections::HashSet, path::PathBuf, str::from_utf8}; +use std::{ + collections::HashSet, + path::{Path, PathBuf}, +}; -use crate::local_infile_handler::LocalInfileHandler; +use crate::{ + error::LocalInfileError, + local_infile_handler::{BoxFuture, GlobalHandler}, +}; -/// Handles local infile requests from filesystem using explicit whitelist of paths. +/// Handles `LOCAL INFILE` requests from filesystem using an explicit whitelist of paths. /// /// Example usage: /// /// ```rust -/// use mysql_async::{OptsBuilder, WhiteListFsLocalInfileHandler}; +/// use mysql_async::{OptsBuilder, WhiteListFsHandler}; /// /// # let database_url = "mysql://root:password@127.0.0.1:3307/mysql"; /// let mut opts = OptsBuilder::from_opts(database_url); -/// opts.local_infile_handler(Some(WhiteListFsLocalInfileHandler::new( +/// opts.local_infile_handler(Some(WhiteListFsHandler::new( /// &["path/to/local_infile.txt"][..], /// ))); /// ``` #[derive(Clone, Debug)] -pub struct WhiteListFsLocalInfileHandler { +pub struct WhiteListFsHandler { white_list: HashSet, } -impl WhiteListFsLocalInfileHandler { - pub fn new(white_list: B) -> WhiteListFsLocalInfileHandler +impl WhiteListFsHandler { + pub fn new(white_list: B) -> WhiteListFsHandler where A: Into, B: IntoIterator, @@ -40,24 +48,24 @@ impl WhiteListFsLocalInfileHandler { for path in white_list.into_iter() { white_list_set.insert(Into::::into(path)); } - WhiteListFsLocalInfileHandler { + WhiteListFsHandler { white_list: white_list_set, } } } -impl LocalInfileHandler for WhiteListFsLocalInfileHandler { - fn handle(&self, file_name: &[u8]) -> super::InfileHandlerFuture { - let path: PathBuf = match from_utf8(file_name) { - Ok(path_str) => path_str.into(), - Err(_) => return Box::pin(futures_util::future::err("Invalid file name".into())), - }; - - if !self.white_list.contains(&path) { - let err_msg = format!("Path `{}' is not in white list", path.display()); - return Box::pin(futures_util::future::err(err_msg.into())); +impl GlobalHandler for WhiteListFsHandler { + fn handle(&self, file_name: &[u8]) -> BoxFuture<'static, super::InfileData> { + let path = String::from_utf8_lossy(file_name); + let path = self + .white_list + .get(Path::new(&*path)) + .cloned() + .ok_or_else(|| LocalInfileError::PathIsNotInTheWhiteList(path.into_owned())); + async move { + let file = File::open(path?).await?; + Ok(Box::pin(ReaderStream::new(file)) as super::InfileData) } - - Box::pin(async move { Ok(Box::new(File::open(path.to_owned()).await?) as Box<_>) }) + .boxed() } } diff --git a/src/local_infile_handler/mod.rs b/src/local_infile_handler/mod.rs index 6d35df4a..7d222a2b 100644 --- a/src/local_infile_handler/mod.rs +++ b/src/local_infile_handler/mod.rs @@ -1,4 +1,4 @@ -// Copyright (c) 2017 Anatoly Ikorsky +// Copyright (c) 2017-2022 mysql_async Contributors. // // Licensed under the Apache License, Version 2.0 // or the MIT @@ -6,110 +6,85 @@ // option. All files in the project carrying such notice may not be copied, // modified, or distributed except according to those terms. -use crate::error; -use mysql_common::uuid::Uuid; +use bytes::Bytes; +use futures_core::stream::BoxStream; -use std::{fmt, future::Future, marker::Unpin, pin::Pin, sync::Arc}; -use tokio::io::AsyncRead; +use std::{ + fmt, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, +}; + +use crate::error::LocalInfileError; pub mod builtin; -/// Trait used to handle local infile requests. -/// -/// Be aware of security issues with [LOAD DATA LOCAL][1]. -/// Using [`crate::WhiteListFsLocalInfileHandler`] is advised. -/// -/// Simple handler example: -/// -/// ```rust -/// # use mysql_async::{prelude::*, test_misc::get_opts, OptsBuilder, Result, Error}; -/// # use std::env; -/// # #[tokio::main] -/// # async fn main() -> Result<()> { -/// # -/// /// This example hanlder will return contained bytes in response to a local infile request. -/// struct ExampleHandler(&'static [u8]); -/// -/// impl LocalInfileHandler for ExampleHandler { -/// fn handle(&self, _: &[u8]) -> mysql_async::InfileHandlerFuture { -/// let handler = Box::new(self.0) as Box<_>; -/// Box::pin(async move { Ok(handler) }) -/// } -/// } -/// -/// # let database_url = get_opts(); -/// -/// let opts = OptsBuilder::from_opts(database_url) -/// .local_infile_handler(Some(ExampleHandler(b"foobar"))); +type BoxFuture<'a, T> = futures_core::future::BoxFuture<'a, Result>; + +/// LOCAL INFILE data is a stream of `std::io::Result`. /// -/// let pool = mysql_async::Pool::new(opts); +/// The driver will send this data to the server in response to a LOCAL INFILE request. +pub type InfileData = BoxStream<'static, std::io::Result>; + +/// Global, `Opts`-level `LOCAL INFILE` handler (see ["LOCAL INFILE Handlers"][2] section +/// of the `README.md`). /// -/// let mut conn = pool.get_conn().await?; -/// conn.query_drop("CREATE TEMPORARY TABLE tmp (a TEXT);").await?; -/// match conn.query_drop("LOAD DATA LOCAL INFILE 'baz' INTO TABLE tmp;").await { -/// Ok(()) => (), -/// Err(Error::Server(ref err)) if err.code == 1148 => { -/// // The used command is not allowed with this MySQL version -/// return Ok(()); -/// }, -/// Err(Error::Server(ref err)) if err.code == 3948 => { -/// // Loading local data is disabled; -/// // this must be enabled on both the client and server sides -/// return Ok(()); -/// } -/// e @ Err(_) => e.unwrap(), -/// }; -/// let result: Vec = conn.exec("SELECT * FROM tmp", ()).await?; +/// **Warning:** You should be aware of [Security Considerations for LOAD DATA LOCAL][1]. /// -/// assert_eq!(result.len(), 1); -/// assert_eq!(result[0], "foobar"); +/// The purpose of the handler is to emit infile data in response to a file name. +/// This handler will be called if there is no [`LocalHandler`] installed for the connection. /// -/// drop(conn); // dropped connection will go to the pool +/// The library will call this handler in response to a LOCAL INFILE request from the server. +/// The server, in its turn, will emit LOCAL INFILE requests in response to a `LOAD DATA LOCAL` +/// queries: /// -/// pool.disconnect().await?; -/// # Ok(()) -/// # } +/// ```sql +/// LOAD DATA LOCAL INFILE '' INTO TABLE ; /// ``` /// -/// [1]: https://dev.mysql.com/doc/refman/8.0/en/load-data-local.html -pub trait LocalInfileHandler: Sync + Send { - /// `file_name` is the file name in `LOAD DATA LOCAL INFILE '' INTO TABLE ...;` - /// query. - fn handle(&self, file_name: &[u8]) -> InfileHandlerFuture; +/// [1]: https://dev.mysql.com/doc/refman/8.0/en/load-data-local-security.html +/// [2]: ../#local-infile-handlers +pub trait GlobalHandler: Send + Sync + 'static { + fn handle(&self, file_name: &[u8]) -> BoxFuture<'static, InfileData>; +} + +impl GlobalHandler for T +where + T: for<'a> Fn(&'a [u8]) -> BoxFuture<'static, InfileData>, + T: Send + Sync + 'static, +{ + fn handle(&self, file_name: &[u8]) -> BoxFuture<'static, InfileData> { + (self)(file_name) + } } -pub type InfileHandlerFuture = Pin< - Box< - dyn Future, error::Error>> - + Send - + 'static, - >, ->; +static HANDLER_ID: AtomicUsize = AtomicUsize::new(0); -/// Object used to wrap `T: LocalInfileHandler` inside of Opts. #[derive(Clone)] -pub struct LocalInfileHandlerObject(Uuid, Arc); +pub struct GlobalHandlerObject(usize, Arc); -impl LocalInfileHandlerObject { - pub fn new(handler: T) -> Self { - LocalInfileHandlerObject(Uuid::new_v4(), Arc::new(handler)) +impl GlobalHandlerObject { + pub(crate) fn new(handler: T) -> Self { + Self(HANDLER_ID.fetch_add(1, Ordering::SeqCst), Arc::new(handler)) } - pub fn clone_inner(&self) -> Arc { + pub(crate) fn clone_inner(&self) -> Arc { self.1.clone() } } -impl PartialEq for LocalInfileHandlerObject { - fn eq(&self, other: &LocalInfileHandlerObject) -> bool { - self.0.eq(&other.0) +impl PartialEq for GlobalHandlerObject { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 } } -impl Eq for LocalInfileHandlerObject {} +impl Eq for GlobalHandlerObject {} -impl fmt::Debug for LocalInfileHandlerObject { +impl fmt::Debug for GlobalHandlerObject { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Local infile handler object") + f.debug_tuple("GlobalHandlerObject").field(&"..").finish() } } diff --git a/src/opts.rs b/src/opts.rs index e2c4de4f..2e98c41f 100644 --- a/src/opts.rs +++ b/src/opts.rs @@ -23,7 +23,7 @@ use std::{ use crate::{ consts::CapabilityFlags, error::*, - local_infile_handler::{LocalInfileHandler, LocalInfileHandlerObject}, + local_infile_handler::{GlobalHandler, GlobalHandlerObject}, }; /// Default pool constraints. @@ -329,7 +329,7 @@ pub(crate) struct MysqlOpts { tcp_nodelay: bool, /// Local infile handler - local_infile_handler: Option, + local_infile_handler: Option, /// Connection pool options (defaults to [`PoolOpts::default`]). pool_opts: PoolOpts, @@ -538,7 +538,7 @@ impl Opts { } /// Handler for local infile requests (defaults to `None`). - pub fn local_infile_handler(&self) -> Option> { + pub fn local_infile_handler(&self) -> Option> { self.inner .mysql_opts .local_infile_handler @@ -897,12 +897,12 @@ impl OptsBuilder { self } - /// Defines local infile handler. See [`Opts::local_infile_handler`]. + /// Defines _global_ LOCAL INFILE handler (see crate-level docs). pub fn local_infile_handler(mut self, handler: Option) -> Self where - T: LocalInfileHandler + 'static, + T: GlobalHandler, { - self.opts.local_infile_handler = handler.map(LocalInfileHandlerObject::new); + self.opts.local_infile_handler = handler.map(GlobalHandlerObject::new); self } diff --git a/tests/exports.rs b/tests/exports.rs index 3df07ebc..c8b13137 100644 --- a/tests/exports.rs +++ b/tests/exports.rs @@ -4,12 +4,12 @@ use mysql_async::{ futures::{DisconnectPool, GetConn}, params, prelude::{ - BatchQuery, ConvIr, FromRow, FromValue, LocalInfileHandler, Protocol, Query, Queryable, + BatchQuery, ConvIr, FromRow, FromValue, GlobalHandler, Protocol, Query, Queryable, StatementLike, ToValue, }, BinaryProtocol, Column, Conn, Deserialized, DriverError, Error, FromRowError, FromValueError, IoError, IsolationLevel, Opts, OptsBuilder, Params, ParseError, Pool, PoolConstraints, PoolOpts, QueryResult, Result, Row, Serialized, ServerError, SslOpts, Statement, TextProtocol, - Transaction, TxOpts, UrlError, Value, WhiteListFsLocalInfileHandler, - DEFAULT_INACTIVE_CONNECTION_TTL, DEFAULT_TTL_CHECK_INTERVAL, + Transaction, TxOpts, UrlError, Value, WhiteListFsHandler, DEFAULT_INACTIVE_CONNECTION_TTL, + DEFAULT_TTL_CHECK_INTERVAL, };