Refactor pd clients
Signed-off-by: Nick Cameron <nrc@ncameron.org>
nrc committed Jul 26, 2020
1 parent f16df2d commit bc9ae71
Showing 4 changed files with 251 additions and 207 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

209 changes: 106 additions & 103 deletions src/pd/retry.rs
@@ -3,7 +3,6 @@
 //! A utility module for managing and retrying PD requests.
 
 use async_trait::async_trait;
-use futures::prelude::*;
 use futures_timer::Delay;
 use grpcio::Environment;
 use kvproto::metapb;
@@ -47,6 +46,31 @@ impl<Cl> RetryClient<Cl> {
     }
 }
 
+macro_rules! retry {
+    ($self: ident, |$cluster: ident| $call: expr) => {{
+        let mut last_err = Ok(());
+        for _ in 0..LEADER_CHANGE_RETRY {
+            let $cluster = &$self.cluster.read().await.0;
+            match $call.await {
+                Ok(r) => return Ok(r),
+                Err(e) => last_err = Err(e),
+            }
+
+            let mut reconnect_count = MAX_REQUEST_COUNT;
+            while let Err(e) = $self.reconnect(RECONNECT_INTERVAL_SEC).await {
+                reconnect_count -= 1;
+                if reconnect_count == 0 {
+                    return Err(e);
+                }
+                Delay::new(Duration::from_secs(RECONNECT_INTERVAL_SEC)).await;
+            }
+        }
+
+        last_err?;
+        unreachable!();
+    }};
+}
+
 impl RetryClient<Cluster> {
     pub async fn connect(
         env: Arc<Environment>,
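
[Editor's note] The new retry! macro inlines the retry loop at each call site. A call such as retry!(self, |cluster| cluster.get_store(id, self.timeout)) expands roughly as follows — a hand-expanded sketch for illustration, not compiler output:

    let mut last_err = Ok(());
    for _ in 0..LEADER_CHANGE_RETRY {
        // Take the current cluster under the read lock. Because the call is
        // inlined by the macro, it may borrow `self` (e.g. `self.timeout`).
        let cluster = &self.cluster.read().await.0;
        match cluster.get_store(id, self.timeout).await {
            Ok(r) => return Ok(r),
            Err(e) => last_err = Err(e),
        }

        // The request failed: try to reconnect to PD, sleeping between
        // attempts and giving up after MAX_REQUEST_COUNT failures.
        let mut reconnect_count = MAX_REQUEST_COUNT;
        while let Err(e) = self.reconnect(RECONNECT_INTERVAL_SEC).await {
            reconnect_count -= 1;
            if reconnect_count == 0 {
                return Err(e);
            }
            Delay::new(Duration::from_secs(RECONNECT_INTERVAL_SEC)).await;
        }
    }

    // All LEADER_CHANGE_RETRY attempts failed: surface the last error.
    last_err?;
    unreachable!();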
@@ -68,31 +92,25 @@ impl RetryClient<Cluster> {
 
     // These get_* functions will try multiple times to make a request, reconnecting as necessary.
     pub async fn get_region(self: Arc<Self>, key: Vec<u8>) -> Result<Region> {
-        let timeout = self.timeout;
-        retry_request(self, move |cluster| {
-            cluster.get_region(key.clone(), timeout)
-        })
-        .await
+        retry!(self, |cluster| cluster
+            .get_region(key.clone(), self.timeout))
     }
 
     pub async fn get_region_by_id(self: Arc<Self>, id: RegionId) -> Result<Region> {
-        let timeout = self.timeout;
-        retry_request(self, move |cluster| cluster.get_region_by_id(id, timeout)).await
+        retry!(self, |cluster| cluster.get_region_by_id(id, self.timeout))
     }
 
     pub async fn get_store(self: Arc<Self>, id: StoreId) -> Result<metapb::Store> {
-        let timeout = self.timeout;
-        retry_request(self, move |cluster| cluster.get_store(id, timeout)).await
+        retry!(self, |cluster| cluster.get_store(id, self.timeout))
     }
 
     #[allow(dead_code)]
     pub async fn get_all_stores(self: Arc<Self>) -> Result<Vec<metapb::Store>> {
-        let timeout = self.timeout;
-        retry_request(self, move |cluster| cluster.get_all_stores(timeout)).await
+        retry!(self, |cluster| cluster.get_all_stores(self.timeout))
     }
 
     pub async fn get_timestamp(self: Arc<Self>) -> Result<Timestamp> {
-        retry_request(self, move |cluster| cluster.get_timestamp()).await
+        retry!(self, |cluster| cluster.get_timestamp())
     }
 }
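
[Editor's note] Callers are unaffected by the switch from retry_request to retry!: the loop stays hidden behind the same get_* methods. A minimal usage sketch, assuming a client already obtained from RetryClient::connect (the surrounding function and the store id are hypothetical):

    async fn example(pd: Arc<RetryClient<Cluster>>) -> Result<()> {
        // Each call retries up to LEADER_CHANGE_RETRY times, reconnecting to
        // PD between failed attempts; only the last error reaches the caller.
        let region = pd.clone().get_region(b"some-key".to_vec()).await?;
        let _ts = pd.clone().get_timestamp().await?;
        let _store = pd.get_store(1 /* hypothetical store id */).await?;
        let _ = region;
        Ok(())
    }

Note that the get_* methods take self: Arc<Self>, so the Arc is cloned for each call.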

@@ -109,7 +127,6 @@ impl fmt::Debug for RetryClient {
 trait Reconnect {
     type Cl;
     async fn reconnect(&self, interval_sec: u64) -> Result<()>;
-    async fn with_cluster<T, F: Fn(&Self::Cl) -> T + Send + Sync>(&self, f: F) -> T;
 }
 
 #[async_trait]
@@ -129,40 +146,6 @@ impl Reconnect for RetryClient<Cluster> {
         }
         Ok(())
     }
-
-    async fn with_cluster<T, F: Fn(&Cluster) -> T + Send + Sync>(&self, f: F) -> T {
-        f(&self.cluster.read().await.0)
-    }
 }
-
-async fn retry_request<Rc, Resp, Func, RespFuture>(client: Arc<Rc>, func: Func) -> Result<Resp>
-where
-    Rc: Reconnect,
-    Resp: Send + 'static,
-    Func: Fn(&Rc::Cl) -> RespFuture + Send + Sync,
-    RespFuture: Future<Output = Result<Resp>> + Send + 'static,
-{
-    let mut last_err = Ok(());
-    for _ in 0..LEADER_CHANGE_RETRY {
-        let fut = client.with_cluster(&func).await;
-        match fut.await {
-            Ok(r) => return Ok(r),
-            Err(e) => last_err = Err(e),
-        }
-
-        // Reconnect.
-        let mut reconnect_count = MAX_REQUEST_COUNT;
-        while let Err(e) = client.reconnect(RECONNECT_INTERVAL_SEC).await {
-            reconnect_count -= 1;
-            if reconnect_count == 0 {
-                return Err(e);
-            }
-            Delay::new(Duration::from_secs(RECONNECT_INTERVAL_SEC)).await;
-        }
-    }
-
-    last_err?;
-    unreachable!();
-}
 
 #[cfg(test)]
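
[Editor's note] Why a macro replaces the generic function: retry_request's closure had to return a future that is Send + 'static, so it could not borrow from self, and call sites had to copy fields out first; the with_cluster trait method existed only to hand the closure a cluster reference. The macro expands in place, so the call expression can borrow self directly. Before/after at one call site, both lines taken from the hunks above:

    // Before: the 'static bound forces a copy of self.timeout.
    let timeout = self.timeout;
    retry_request(self, move |cluster| cluster.get_store(id, timeout)).await

    // After: the macro expands in place, so the call borrows self directly.
    retry!(self, |cluster| cluster.get_store(id, self.timeout))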
@@ -176,6 +159,7 @@ mod test {
     fn test_reconnect() {
         struct MockClient {
             reconnect_count: Mutex<usize>,
+            cluster: RwLock<((), Instant)>,
         }
 
         #[async_trait]
@@ -187,82 +171,101 @@ impl Reconnect for MockClient {
                 // Not actually unimplemented, we just don't care about the error.
                 Err(Error::unimplemented())
             }
-
-            async fn with_cluster<T, F: Fn(&Self::Cl) -> T + Send + Sync>(&self, f: F) -> T {
-                f(&())
-            }
         }
 
-        let client = Arc::new(MockClient {
-            reconnect_count: Mutex::new(0),
-        });
+        async fn retry_err(client: Arc<MockClient>) -> Result<()> {
+            retry!(client, |_c| ready(Err(internal_err!("whoops"))))
+        }
 
-        fn ready_err(_: &()) -> impl Future<Output = Result<()>> + Send + 'static {
-            ready(Err(internal_err!("whoops")))
+        async fn retry_ok(client: Arc<MockClient>) -> Result<()> {
+            retry!(client, |_c| ready(Ok::<_, Error>(())))
         }
 
-        let result = executor::block_on(retry_request(client.clone(), ready_err));
-        assert!(result.is_err());
-        assert_eq!(*client.reconnect_count.lock().unwrap(), MAX_REQUEST_COUNT);
-
-        *client.reconnect_count.lock().unwrap() = 0;
-        let result = executor::block_on(retry_request(client.clone(), |_| ready(Ok(()))));
-        assert!(result.is_ok());
-        assert_eq!(*client.reconnect_count.lock().unwrap(), 0);
+        executor::block_on(async {
+            let client = Arc::new(MockClient {
+                reconnect_count: Mutex::new(0),
+                cluster: RwLock::new(((), Instant::now())),
+            });
+
+            assert!(retry_err(client.clone()).await.is_err());
+            assert_eq!(*client.reconnect_count.lock().unwrap(), MAX_REQUEST_COUNT);
+
+            *client.reconnect_count.lock().unwrap() = 0;
+            assert!(retry_ok(client.clone()).await.is_ok());
+            assert_eq!(*client.reconnect_count.lock().unwrap(), 0);
+        })
     }
 
     #[test]
     fn test_retry() {
         struct MockClient {
-            retry_count: Mutex<usize>,
+            cluster: RwLock<(Mutex<usize>, Instant)>,
         }
 
         #[async_trait]
         impl Reconnect for MockClient {
-            type Cl = ();
+            type Cl = Mutex<usize>;
 
             async fn reconnect(&self, _: u64) -> Result<()> {
                 Ok(())
             }
-
-            async fn with_cluster<T, F: Fn(&Self::Cl) -> T + Send + Sync>(&self, f: F) -> T {
-                *self.retry_count.lock().unwrap() += 1;
-                f(&())
-            }
         }
 
-        let client = Arc::new(MockClient {
-            retry_count: Mutex::new(0),
-        });
-        let max_retries = Arc::new(Mutex::new(1000));
-
-        let result = executor::block_on(retry_request(client.clone(), |_| {
-            let mut max_retries = max_retries.lock().unwrap();
-            *max_retries -= 1;
-            if *max_retries == 0 {
-                ready(Ok(()))
-            } else {
-                ready(Err(internal_err!("whoops")))
-            }
-        }));
-        assert!(result.is_err());
-        assert_eq!(*client.retry_count.lock().unwrap(), LEADER_CHANGE_RETRY);
-
-        let client = Arc::new(MockClient {
-            retry_count: Mutex::new(0),
-        });
-        let max_retries = Arc::new(Mutex::new(2));
-
-        let result = executor::block_on(retry_request(client.clone(), |_| {
-            let mut max_retries = max_retries.lock().unwrap();
-            *max_retries -= 1;
-            if *max_retries == 0 {
-                ready(Ok(()))
-            } else {
-                ready(Err(internal_err!("whoops")))
-            }
-        }));
-        assert!(result.is_ok());
-        assert_eq!(*client.retry_count.lock().unwrap(), 2);
+        async fn retry_max_err(
+            client: Arc<MockClient>,
+            max_retries: Arc<Mutex<usize>>,
+        ) -> Result<()> {
+            retry!(client, |c| {
+                let mut c = c.lock().unwrap();
+                *c += 1;
+
+                let mut max_retries = max_retries.lock().unwrap();
+                *max_retries -= 1;
+                if *max_retries == 0 {
+                    ready(Ok(()))
+                } else {
+                    ready(Err(internal_err!("whoops")))
+                }
+            })
+        }
+
+        async fn retry_max_ok(
+            client: Arc<MockClient>,
+            max_retries: Arc<Mutex<usize>>,
+        ) -> Result<()> {
+            retry!(client, |c| {
+                let mut c = c.lock().unwrap();
+                *c += 1;
+
+                let mut max_retries = max_retries.lock().unwrap();
+                *max_retries -= 1;
+                if *max_retries == 0 {
+                    ready(Ok(()))
+                } else {
+                    ready(Err(internal_err!("whoops")))
+                }
+            })
+        }
+
+        executor::block_on(async {
+            let client = Arc::new(MockClient {
+                cluster: RwLock::new((Mutex::new(0), Instant::now())),
+            });
+            let max_retries = Arc::new(Mutex::new(1000));
+
+            assert!(retry_max_err(client.clone(), max_retries).await.is_err());
+            assert_eq!(
+                *client.cluster.read().await.0.lock().unwrap(),
+                LEADER_CHANGE_RETRY
+            );
+
+            let client = Arc::new(MockClient {
+                cluster: RwLock::new((Mutex::new(0), Instant::now())),
+            });
+            let max_retries = Arc::new(Mutex::new(2));
+
+            assert!(retry_max_ok(client.clone(), max_retries).await.is_ok());
+            assert_eq!(*client.cluster.read().await.0.lock().unwrap(), 2);
+        })
     }
 }
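
[Editor's note] One consequence visible in these tests: retry! is duck-typed rather than trait-bound. Whatever is passed as $self must expose a cluster: RwLock<(Cl, Instant)> field and a reconnect method, which is why both mock clients gain a cluster field even though only test_retry reads it. A minimal sketch of that implicit interface (DuckClient and MyCl are hypothetical names; here, as in the tests, reconnect comes from the Reconnect impl):

    struct DuckClient {
        // Read by the macro as `$self.cluster.read().await.0`.
        cluster: RwLock<(MyCl, Instant)>,
    }

    #[async_trait]
    impl Reconnect for DuckClient {
        type Cl = MyCl;

        // Called by the macro as `$self.reconnect(RECONNECT_INTERVAL_SEC)`.
        async fn reconnect(&self, _interval_sec: u64) -> Result<()> {
            Ok(())
        }
    }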
13 changes: 6 additions & 7 deletions tikv-client-pd/Cargo.toml
@@ -3,21 +3,20 @@ name = "tikv-client-pd"
 version = "0.0.0"
 edition = "2018"
 
-
 [dependencies]
 async-trait = "0.1"
 derive-new = "0.5"
-kvproto = { git = "https://github.com/pingcap/kvproto.git", rev = "1e28226154c374788f38d3a542fc505cd74720f3", features = [ "prost-codec" ], default-features = false }
 futures = { version = "0.3.5", features = ["compat", "async-await", "thread-pool"] }
-tokio = { version = "0.2", features = ["sync"] }
 grpcio = { version = "0.6", features = [ "secure", "prost-codec" ], default-features = false }
+kvproto = { git = "https://github.com/pingcap/kvproto.git", rev = "1e28226154c374788f38d3a542fc505cd74720f3", features = [ "prost-codec" ], default-features = false }
 log = "0.4"
 
 tikv-client-common = { path = "../tikv-client-common" }
+tokio = { version = "0.2", features = ["sync"] }
 
 [dev-dependencies]
 clap = "2.32"
-tempdir = "0.3"
-tokio = { version = "0.2", features = ["rt-threaded", "macros"] }
-fail = { version = "0.3", features = [ "failpoints" ] }
+fail = { version = "0.3", features = [ "failpoints" ] }
 proptest = "0.9"
 proptest-derive = "0.1.0"
+tempdir = "0.3"
+tokio = { version = "0.2", features = ["rt-threaded", "macros"] }
