Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ license = "MIT/Apache-2.0"
name = "mysql_async"
readme = "README.md"
repository = "https://github.com/blackbeam/mysql_async"
version = "0.28.1"
version = "0.29.0"
exclude = ["test/*"]
edition = "2018"
categories = ["asynchronous", "database"]
Expand All @@ -20,16 +20,17 @@ futures-core = "0.3"
futures-util = "0.3"
futures-sink = "0.3"
lazy_static = "1"
lru = "0.6.0"
mio = "0.7.7"
mysql_common = { version = "0.27.2", default-features = false }
lru = "0.7.0"
mio = { version = "0.8.0", features = ["os-poll", "net"] }
mysql_common = { version = "0.28.0", default-features = false }
native-tls = "0.2"
once_cell = "1.7.2"
pem = "0.8.1"
pem = "1.0.1"
percent-encoding = "2.1.0"
pin-project = "1.0.2"
serde = "1"
serde_json = "1"
socket2 = "0.4.2"
thiserror = "1.0.4"
tokio = { version = "1.0", features = ["io-util", "fs", "net", "time", "rt"] }
tokio-util = { version = "0.6.0", features = ["codec"] }
Expand All @@ -47,10 +48,9 @@ rand = "0.8.0"
[features]
default = [
"flate2/zlib",
"mysql_common/bigdecimal",
"mysql_common/chrono",
"mysql_common/bigdecimal03",
"mysql_common/rust_decimal",
"mysql_common/time",
"mysql_common/time03",
"mysql_common/uuid",
"mysql_common/frunk",
]
Expand Down
9 changes: 8 additions & 1 deletion azure-pipelines.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ jobs:

- job: "TestBasicWindows"
pool:
vmImage: "vs2017-win2016"
vmImage: "windows-2019"
strategy:
maxParallel: 10
matrix:
Expand Down Expand Up @@ -136,6 +136,10 @@ jobs:
strategy:
maxParallel: 10
matrix:
v107:
DB_VERSION: "10.7"
v106:
DB_VERSION: "10.6"
v105:
DB_VERSION: "10.5"
v104:
Expand All @@ -156,12 +160,15 @@ jobs:
displayName: Install docker
- bash: |
git clone https://github.com/blackbeam/rust-mysql-simple.git
cd rust-mysql-simple
git checkout 8d745ee
displayName: Clone rust-mysql-simple (for ssl certs)
- bash: |
docker run --rm -d \
--name container \
-v `pwd`:/root \
-p 3307:3306 \
-e MARIADB_ROOT_PASSWORD=password \
-e MYSQL_ROOT_PASSWORD=password \
mariadb:$(DB_VERSION) \
--max-allowed-packet=36700160 \
Expand Down
5 changes: 5 additions & 0 deletions src/conn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -429,9 +429,14 @@ impl Conn {
async fn handle_handshake(&mut self) -> Result<()> {
let packet = self.read_packet().await?;
let handshake = ParseBuf(&*packet).parse::<HandshakePacket>(())?;

// Handshake scramble is always 21 bytes length (20 + zero terminator)
self.inner.nonce = {
let mut nonce = Vec::from(handshake.scramble_1_ref());
nonce.extend_from_slice(handshake.scramble_2_ref().unwrap_or(&[][..]));
// Trim zero terminator. Fill with zeroes if nonce
// is somehow smaller than 20 bytes (this matches the server behavior).
nonce.resize(20, 0);
nonce
};

Expand Down
128 changes: 32 additions & 96 deletions src/io/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@ pub use self::{read_packet::ReadPacket, write_packet::WritePacket};

use bytes::BytesMut;
use futures_core::{ready, stream};
use futures_util::stream::{FuturesUnordered, StreamExt};
use mio::net::{TcpKeepalive, TcpSocket};
use mysql_common::proto::codec::PacketCodec as PacketCodecInner;
use native_tls::{Certificate, Identity, TlsConnector};
use pin_project::pin_project;
use socket2::{Socket as Socket2Socket, TcpKeepalive};
#[cfg(unix)]
use tokio::io::AsyncWriteExt;
use tokio::{
Expand All @@ -35,14 +34,17 @@ use std::{
Read,
},
mem::replace,
net::{SocketAddr, ToSocketAddrs},
ops::{Deref, DerefMut},
pin::Pin,
task::{Context, Poll},
time::Duration,
};

use crate::{buffer_pool::PooledBuf, error::IoError, opts::SslOpts};
use crate::{
buffer_pool::PooledBuf,
error::IoError,
opts::{HostPortOrUrl, SslOpts, DEFAULT_PORT},
};

#[cfg(unix)]
use crate::io::socket::Socket;
Expand Down Expand Up @@ -208,6 +210,7 @@ impl Endpoint {
.map(|x| vec![x])
.or_else(|_| {
pem::parse_many(&*root_cert_data)
.unwrap_or_default()
.iter()
.map(pem::encode)
.map(|s| Certificate::from_pem(s.as_bytes()))
Expand Down Expand Up @@ -354,108 +357,41 @@ impl Stream {
}
}

pub(crate) async fn connect_tcp<S>(addr: S, keepalive: Option<Duration>) -> io::Result<Stream>
where
S: ToSocketAddrs,
{
// TODO: Use tokio to setup keepalive (see tokio-rs/tokio#3082)
async fn connect_stream(
addr: SocketAddr,
keepalive_opts: Option<TcpKeepalive>,
) -> io::Result<TcpStream> {
let socket = if addr.is_ipv6() {
TcpSocket::new_v6()?
} else {
TcpSocket::new_v4()?
};

if let Some(keepalive_opts) = keepalive_opts {
socket.set_keepalive_params(keepalive_opts)?;
pub(crate) async fn connect_tcp(
addr: &HostPortOrUrl,
keepalive: Option<Duration>,
) -> io::Result<Stream> {
let tcp_stream = match addr {
HostPortOrUrl::HostPort(host, port) => {
TcpStream::connect((host.as_str(), *port)).await?
}
HostPortOrUrl::Url(url) => {
let addrs = url.socket_addrs(|| Some(DEFAULT_PORT))?;
TcpStream::connect(&*addrs).await?
}
};

let stream = tokio::task::spawn_blocking(move || {
let mut stream = socket.connect(addr)?;
let mut poll = mio::Poll::new()?;
let mut events = mio::Events::with_capacity(1024);

poll.registry()
.register(&mut stream, mio::Token(0), mio::Interest::WRITABLE)?;

loop {
poll.poll(&mut events, None)?;

for event in &events {
if event.token() == mio::Token(0) && event.is_error() {
return Err(io::Error::new(
io::ErrorKind::ConnectionRefused,
"Connection refused",
));
}

if event.token() == mio::Token(0) && event.is_writable() {
// The socket connected (probably, it could still be a spurious
// wakeup)
return Ok::<_, io::Error>(stream);
}
}
}
})
.await??;

if let Some(duration) = keepalive {
#[cfg(unix)]
let std_stream = unsafe {
let socket = unsafe {
use std::os::unix::prelude::*;
let fd = stream.into_raw_fd();
std::net::TcpStream::from_raw_fd(fd)
let fd = tcp_stream.as_raw_fd();
Socket2Socket::from_raw_fd(fd)
};

#[cfg(windows)]
let std_stream = unsafe {
let socket = unsafe {
use std::os::windows::prelude::*;
let fd = stream.into_raw_socket();
std::net::TcpStream::from_raw_socket(fd)
let sock = tcp_stream.as_raw_socket();
Socket2Socket::from_raw_socket(sock)
};

Ok(TcpStream::from_std(std_stream)?)
socket.set_tcp_keepalive(&TcpKeepalive::new().with_time(duration))?;
std::mem::forget(socket);
}

let keepalive_opts = keepalive.map(|time| TcpKeepalive::new().with_time(time));

match addr.to_socket_addrs() {
Ok(addresses) => {
let mut streams = FuturesUnordered::new();

for address in addresses {
streams.push(connect_stream(address, keepalive_opts.clone()));
}

let mut err = None;
while let Some(stream) = streams.next().await {
match stream {
Err(e) => {
err = Some(e);
}
Ok(stream) => {
return Ok(Stream {
closed: false,
codec: Box::new(Framed::new(stream.into(), PacketCodec::default()))
.into(),
});
}
}
}

if let Some(e) = err {
Err(e)
} else {
Err(io::Error::new(
io::ErrorKind::InvalidInput,
"could not resolve to any address",
))
}
}
Err(err) => Err(err),
}
Ok(Stream {
closed: false,
codec: Box::new(Framed::new(tcp_stream.into(), PacketCodec::default())).into(),
})
}

#[cfg(unix)]
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@
#[cfg(feature = "nightly")]
extern crate test;

pub use mysql_common::{chrono, constants as consts, params, time, uuid};
pub use mysql_common::{constants as consts, params};

use std::sync::Arc;

Expand Down
18 changes: 2 additions & 16 deletions src/opts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ use url::{Host, Url};
use std::{
borrow::Cow,
convert::TryFrom,
io,
net::{Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs},
net::{Ipv4Addr, Ipv6Addr},
path::Path,
str::FromStr,
sync::Arc,
Expand All @@ -40,7 +39,7 @@ const_assert!(
pub const DEFAULT_STMT_CACHE_SIZE: usize = 32;

/// Default server port.
const DEFAULT_PORT: u16 = 3306;
pub const DEFAULT_PORT: u16 = 3306;

/// Default `inactive_connection_ttl` of a pool.
///
Expand All @@ -67,19 +66,6 @@ impl Default for HostPortOrUrl {
}
}

impl ToSocketAddrs for HostPortOrUrl {
type Iter = vec::IntoIter<SocketAddr>;

fn to_socket_addrs(&self) -> io::Result<vec::IntoIter<SocketAddr>> {
let res = match self {
Self::Url(url) => url.socket_addrs(|| Some(DEFAULT_PORT))?.into_iter(),
Self::HostPort(host, port) => (host.as_ref(), *port).to_socket_addrs()?,
};

Ok(res)
}
}

impl HostPortOrUrl {
pub fn get_ip_or_hostname(&self) -> &str {
match self {
Expand Down
10 changes: 5 additions & 5 deletions tests/exports.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
#[allow(unused_imports)]
use mysql_async::{
chrono, consts, from_row, from_row_opt, from_value, from_value_opt,
consts, from_row, from_row_opt, from_value, from_value_opt,
futures::{DisconnectPool, GetConn},
params,
prelude::{
BatchQuery, ConvIr, FromRow, FromValue, LocalInfileHandler, Protocol, Query, Queryable,
StatementLike, ToValue,
},
time, uuid, BinaryProtocol, Column, Conn, Deserialized, DriverError, Error, FromRowError,
FromValueError, IoError, IsolationLevel, Opts, OptsBuilder, Params, ParseError, Pool,
PoolConstraints, PoolOpts, QueryResult, Result, Row, Serialized, ServerError, SslOpts,
Statement, TextProtocol, Transaction, TxOpts, UrlError, Value, WhiteListFsLocalInfileHandler,
BinaryProtocol, Column, Conn, Deserialized, DriverError, Error, FromRowError, FromValueError,
IoError, IsolationLevel, Opts, OptsBuilder, Params, ParseError, Pool, PoolConstraints,
PoolOpts, QueryResult, Result, Row, Serialized, ServerError, SslOpts, Statement, TextProtocol,
Transaction, TxOpts, UrlError, Value, WhiteListFsLocalInfileHandler,
DEFAULT_INACTIVE_CONNECTION_TTL, DEFAULT_TTL_CHECK_INTERVAL,
};