diff --git a/Cargo.toml b/Cargo.toml index 87702800..54d6390b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ license = "MIT/Apache-2.0" name = "mysql_async" readme = "README.md" repository = "https://github.com/blackbeam/mysql_async" -version = "0.28.1" +version = "0.29.0" exclude = ["test/*"] edition = "2018" categories = ["asynchronous", "database"] @@ -20,16 +20,17 @@ futures-core = "0.3" futures-util = "0.3" futures-sink = "0.3" lazy_static = "1" -lru = "0.6.0" -mio = "0.7.7" -mysql_common = { version = "0.27.2", default-features = false } +lru = "0.7.0" +mio = { version = "0.8.0", features = ["os-poll", "net"] } +mysql_common = { version = "0.28.0", default-features = false } native-tls = "0.2" once_cell = "1.7.2" -pem = "0.8.1" +pem = "1.0.1" percent-encoding = "2.1.0" pin-project = "1.0.2" serde = "1" serde_json = "1" +socket2 = "0.4.2" thiserror = "1.0.4" tokio = { version = "1.0", features = ["io-util", "fs", "net", "time", "rt"] } tokio-util = { version = "0.6.0", features = ["codec"] } @@ -47,10 +48,9 @@ rand = "0.8.0" [features] default = [ "flate2/zlib", - "mysql_common/bigdecimal", - "mysql_common/chrono", + "mysql_common/bigdecimal03", "mysql_common/rust_decimal", - "mysql_common/time", + "mysql_common/time03", "mysql_common/uuid", "mysql_common/frunk", ] diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 62f453ed..fe273551 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -51,7 +51,7 @@ jobs: - job: "TestBasicWindows" pool: - vmImage: "vs2017-win2016" + vmImage: "windows-2019" strategy: maxParallel: 10 matrix: @@ -136,6 +136,10 @@ jobs: strategy: maxParallel: 10 matrix: + v107: + DB_VERSION: "10.7" + v106: + DB_VERSION: "10.6" v105: DB_VERSION: "10.5" v104: @@ -156,12 +160,15 @@ jobs: displayName: Install docker - bash: | git clone https://github.com/blackbeam/rust-mysql-simple.git + cd rust-mysql-simple + git checkout 8d745ee displayName: Clone rust-mysql-simple (for ssl certs) - bash: | docker run --rm -d \ --name container \ -v `pwd`:/root \ -p 3307:3306 \ + -e MARIADB_ROOT_PASSWORD=password \ -e MYSQL_ROOT_PASSWORD=password \ mariadb:$(DB_VERSION) \ --max-allowed-packet=36700160 \ diff --git a/src/conn/mod.rs b/src/conn/mod.rs index 3b828862..872a34f4 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -429,9 +429,14 @@ impl Conn { async fn handle_handshake(&mut self) -> Result<()> { let packet = self.read_packet().await?; let handshake = ParseBuf(&*packet).parse::(())?; + + // Handshake scramble is always 21 bytes length (20 + zero terminator) self.inner.nonce = { let mut nonce = Vec::from(handshake.scramble_1_ref()); nonce.extend_from_slice(handshake.scramble_2_ref().unwrap_or(&[][..])); + // Trim zero terminator. Fill with zeroes if nonce + // is somehow smaller than 20 bytes (this matches the server behavior). + nonce.resize(20, 0); nonce }; diff --git a/src/io/mod.rs b/src/io/mod.rs index 11e22be0..33aac3c2 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -10,11 +10,10 @@ pub use self::{read_packet::ReadPacket, write_packet::WritePacket}; use bytes::BytesMut; use futures_core::{ready, stream}; -use futures_util::stream::{FuturesUnordered, StreamExt}; -use mio::net::{TcpKeepalive, TcpSocket}; use mysql_common::proto::codec::PacketCodec as PacketCodecInner; use native_tls::{Certificate, Identity, TlsConnector}; use pin_project::pin_project; +use socket2::{Socket as Socket2Socket, TcpKeepalive}; #[cfg(unix)] use tokio::io::AsyncWriteExt; use tokio::{ @@ -35,14 +34,17 @@ use std::{ Read, }, mem::replace, - net::{SocketAddr, ToSocketAddrs}, ops::{Deref, DerefMut}, pin::Pin, task::{Context, Poll}, time::Duration, }; -use crate::{buffer_pool::PooledBuf, error::IoError, opts::SslOpts}; +use crate::{ + buffer_pool::PooledBuf, + error::IoError, + opts::{HostPortOrUrl, SslOpts, DEFAULT_PORT}, +}; #[cfg(unix)] use crate::io::socket::Socket; @@ -208,6 +210,7 @@ impl Endpoint { .map(|x| vec![x]) .or_else(|_| { pem::parse_many(&*root_cert_data) + .unwrap_or_default() .iter() .map(pem::encode) .map(|s| Certificate::from_pem(s.as_bytes())) @@ -354,108 +357,41 @@ impl Stream { } } - pub(crate) async fn connect_tcp(addr: S, keepalive: Option) -> io::Result - where - S: ToSocketAddrs, - { - // TODO: Use tokio to setup keepalive (see tokio-rs/tokio#3082) - async fn connect_stream( - addr: SocketAddr, - keepalive_opts: Option, - ) -> io::Result { - let socket = if addr.is_ipv6() { - TcpSocket::new_v6()? - } else { - TcpSocket::new_v4()? - }; - - if let Some(keepalive_opts) = keepalive_opts { - socket.set_keepalive_params(keepalive_opts)?; + pub(crate) async fn connect_tcp( + addr: &HostPortOrUrl, + keepalive: Option, + ) -> io::Result { + let tcp_stream = match addr { + HostPortOrUrl::HostPort(host, port) => { + TcpStream::connect((host.as_str(), *port)).await? } + HostPortOrUrl::Url(url) => { + let addrs = url.socket_addrs(|| Some(DEFAULT_PORT))?; + TcpStream::connect(&*addrs).await? + } + }; - let stream = tokio::task::spawn_blocking(move || { - let mut stream = socket.connect(addr)?; - let mut poll = mio::Poll::new()?; - let mut events = mio::Events::with_capacity(1024); - - poll.registry() - .register(&mut stream, mio::Token(0), mio::Interest::WRITABLE)?; - - loop { - poll.poll(&mut events, None)?; - - for event in &events { - if event.token() == mio::Token(0) && event.is_error() { - return Err(io::Error::new( - io::ErrorKind::ConnectionRefused, - "Connection refused", - )); - } - - if event.token() == mio::Token(0) && event.is_writable() { - // The socket connected (probably, it could still be a spurious - // wakeup) - return Ok::<_, io::Error>(stream); - } - } - } - }) - .await??; - + if let Some(duration) = keepalive { #[cfg(unix)] - let std_stream = unsafe { + let socket = unsafe { use std::os::unix::prelude::*; - let fd = stream.into_raw_fd(); - std::net::TcpStream::from_raw_fd(fd) + let fd = tcp_stream.as_raw_fd(); + Socket2Socket::from_raw_fd(fd) }; - #[cfg(windows)] - let std_stream = unsafe { + let socket = unsafe { use std::os::windows::prelude::*; - let fd = stream.into_raw_socket(); - std::net::TcpStream::from_raw_socket(fd) + let sock = tcp_stream.as_raw_socket(); + Socket2Socket::from_raw_socket(sock) }; - - Ok(TcpStream::from_std(std_stream)?) + socket.set_tcp_keepalive(&TcpKeepalive::new().with_time(duration))?; + std::mem::forget(socket); } - let keepalive_opts = keepalive.map(|time| TcpKeepalive::new().with_time(time)); - - match addr.to_socket_addrs() { - Ok(addresses) => { - let mut streams = FuturesUnordered::new(); - - for address in addresses { - streams.push(connect_stream(address, keepalive_opts.clone())); - } - - let mut err = None; - while let Some(stream) = streams.next().await { - match stream { - Err(e) => { - err = Some(e); - } - Ok(stream) => { - return Ok(Stream { - closed: false, - codec: Box::new(Framed::new(stream.into(), PacketCodec::default())) - .into(), - }); - } - } - } - - if let Some(e) = err { - Err(e) - } else { - Err(io::Error::new( - io::ErrorKind::InvalidInput, - "could not resolve to any address", - )) - } - } - Err(err) => Err(err), - } + Ok(Stream { + closed: false, + codec: Box::new(Framed::new(tcp_stream.into(), PacketCodec::default())).into(), + }) } #[cfg(unix)] diff --git a/src/lib.rs b/src/lib.rs index 545f62f4..f71d3b7e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -91,7 +91,7 @@ #[cfg(feature = "nightly")] extern crate test; -pub use mysql_common::{chrono, constants as consts, params, time, uuid}; +pub use mysql_common::{constants as consts, params}; use std::sync::Arc; diff --git a/src/opts.rs b/src/opts.rs index 2f2e2c8d..e2c4de4f 100644 --- a/src/opts.rs +++ b/src/opts.rs @@ -12,8 +12,7 @@ use url::{Host, Url}; use std::{ borrow::Cow, convert::TryFrom, - io, - net::{Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs}, + net::{Ipv4Addr, Ipv6Addr}, path::Path, str::FromStr, sync::Arc, @@ -40,7 +39,7 @@ const_assert!( pub const DEFAULT_STMT_CACHE_SIZE: usize = 32; /// Default server port. -const DEFAULT_PORT: u16 = 3306; +pub const DEFAULT_PORT: u16 = 3306; /// Default `inactive_connection_ttl` of a pool. /// @@ -67,19 +66,6 @@ impl Default for HostPortOrUrl { } } -impl ToSocketAddrs for HostPortOrUrl { - type Iter = vec::IntoIter; - - fn to_socket_addrs(&self) -> io::Result> { - let res = match self { - Self::Url(url) => url.socket_addrs(|| Some(DEFAULT_PORT))?.into_iter(), - Self::HostPort(host, port) => (host.as_ref(), *port).to_socket_addrs()?, - }; - - Ok(res) - } -} - impl HostPortOrUrl { pub fn get_ip_or_hostname(&self) -> &str { match self { diff --git a/tests/exports.rs b/tests/exports.rs index fa75224f..3df07ebc 100644 --- a/tests/exports.rs +++ b/tests/exports.rs @@ -1,15 +1,15 @@ #[allow(unused_imports)] use mysql_async::{ - chrono, consts, from_row, from_row_opt, from_value, from_value_opt, + consts, from_row, from_row_opt, from_value, from_value_opt, futures::{DisconnectPool, GetConn}, params, prelude::{ BatchQuery, ConvIr, FromRow, FromValue, LocalInfileHandler, Protocol, Query, Queryable, StatementLike, ToValue, }, - time, uuid, BinaryProtocol, Column, Conn, Deserialized, DriverError, Error, FromRowError, - FromValueError, IoError, IsolationLevel, Opts, OptsBuilder, Params, ParseError, Pool, - PoolConstraints, PoolOpts, QueryResult, Result, Row, Serialized, ServerError, SslOpts, - Statement, TextProtocol, Transaction, TxOpts, UrlError, Value, WhiteListFsLocalInfileHandler, + BinaryProtocol, Column, Conn, Deserialized, DriverError, Error, FromRowError, FromValueError, + IoError, IsolationLevel, Opts, OptsBuilder, Params, ParseError, Pool, PoolConstraints, + PoolOpts, QueryResult, Result, Row, Serialized, ServerError, SslOpts, Statement, TextProtocol, + Transaction, TxOpts, UrlError, Value, WhiteListFsLocalInfileHandler, DEFAULT_INACTIVE_CONNECTION_TTL, DEFAULT_TTL_CHECK_INTERVAL, };