From bc9ae7103edbca9d0ab8ddf6b64a14bd796df096 Mon Sep 17 00:00:00 2001
From: Nick Cameron
Date: Thu, 23 Jul 2020 17:23:48 +1200
Subject: [PATCH] Refactor pd clients

Signed-off-by: Nick Cameron
---
 Cargo.lock                    |   1 +
 src/pd/retry.rs               | 209 +++++++++++++++---------------
 tikv-client-pd/Cargo.toml     |  13 +-
 tikv-client-pd/src/cluster.rs | 235 ++++++++++++++++++++--------------
 4 files changed, 251 insertions(+), 207 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index cc2f9f5d..b47a2a92 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1820,6 +1820,7 @@ dependencies = [
 name = "tikv-client-pd"
 version = "0.0.0"
 dependencies = [
+ "async-trait",
 "clap",
 "derive-new",
 "fail",
diff --git a/src/pd/retry.rs b/src/pd/retry.rs
index c8ae2896..bf582c5b 100644
--- a/src/pd/retry.rs
+++ b/src/pd/retry.rs
@@ -3,7 +3,6 @@
 //! A utility module for managing and retrying PD requests.
 
 use async_trait::async_trait;
-use futures::prelude::*;
 use futures_timer::Delay;
 use grpcio::Environment;
 use kvproto::metapb;
@@ -47,6 +46,31 @@ impl RetryClient {
     }
 }
 
+macro_rules! retry {
+    ($self: ident, |$cluster: ident| $call: expr) => {{
+        let mut last_err = Ok(());
+        for _ in 0..LEADER_CHANGE_RETRY {
+            let $cluster = &$self.cluster.read().await.0;
+            match $call.await {
+                Ok(r) => return Ok(r),
+                Err(e) => last_err = Err(e),
+            }
+
+            let mut reconnect_count = MAX_REQUEST_COUNT;
+            while let Err(e) = $self.reconnect(RECONNECT_INTERVAL_SEC).await {
+                reconnect_count -= 1;
+                if reconnect_count == 0 {
+                    return Err(e);
+                }
+                Delay::new(Duration::from_secs(RECONNECT_INTERVAL_SEC)).await;
+            }
+        }
+
+        last_err?;
+        unreachable!();
+    }};
+}
+
 impl RetryClient<Cluster> {
     pub async fn connect(
         env: Arc<Environment>,
@@ -68,31 +92,25 @@ impl RetryClient {
     // These get_* functions will try multiple times to make a request, reconnecting as necessary.
     pub async fn get_region(self: Arc<Self>, key: Vec<u8>) -> Result<Region> {
-        let timeout = self.timeout;
-        retry_request(self, move |cluster| {
-            cluster.get_region(key.clone(), timeout)
-        })
-        .await
+        retry!(self, |cluster| cluster
+            .get_region(key.clone(), self.timeout))
     }
 
     pub async fn get_region_by_id(self: Arc<Self>, id: RegionId) -> Result<Region> {
-        let timeout = self.timeout;
-        retry_request(self, move |cluster| cluster.get_region_by_id(id, timeout)).await
+        retry!(self, |cluster| cluster.get_region_by_id(id, self.timeout))
     }
 
     pub async fn get_store(self: Arc<Self>, id: StoreId) -> Result<metapb::Store> {
-        let timeout = self.timeout;
-        retry_request(self, move |cluster| cluster.get_store(id, timeout)).await
+        retry!(self, |cluster| cluster.get_store(id, self.timeout))
     }
 
     #[allow(dead_code)]
     pub async fn get_all_stores(self: Arc<Self>) -> Result<Vec<metapb::Store>> {
-        let timeout = self.timeout;
-        retry_request(self, move |cluster| cluster.get_all_stores(timeout)).await
+        retry!(self, |cluster| cluster.get_all_stores(self.timeout))
    }
 
     pub async fn get_timestamp(self: Arc<Self>) -> Result<Timestamp> {
-        retry_request(self, move |cluster| cluster.get_timestamp()).await
+        retry!(self, |cluster| cluster.get_timestamp())
     }
 }
 
@@ -109,7 +127,6 @@ impl fmt::Debug for RetryClient {
 trait Reconnect {
     type Cl;
     async fn reconnect(&self, interval_sec: u64) -> Result<()>;
-    async fn with_cluster<T, F: Fn(&Self::Cl) -> T + Send + Sync>(&self, f: F) -> T;
 }
 
 #[async_trait]
@@ -129,40 +146,6 @@ impl Reconnect for RetryClient {
         }
         Ok(())
     }
-
-    async fn with_cluster<T, F: Fn(&Self::Cl) -> T + Send + Sync>(&self, f: F) -> T {
-        f(&self.cluster.read().await.0)
-    }
-}
-
-async fn retry_request<Rc, Resp, Func, RespFuture>(client: Arc<Rc>, func: Func) -> Result<Resp>
-where
-    Rc: Reconnect,
-    Resp: Send + 'static,
-    Func: Fn(&Rc::Cl) -> RespFuture + Send + Sync,
-    RespFuture: Future<Output = Result<Resp>> + Send + 'static,
-{
-    let mut last_err = Ok(());
-    for _ in 0..LEADER_CHANGE_RETRY {
-        let fut = client.with_cluster(&func).await;
-        match fut.await {
-            Ok(r) => return Ok(r),
-            Err(e) => last_err = Err(e),
-        }
-
-        // Reconnect.
-        let mut reconnect_count = MAX_REQUEST_COUNT;
-        while let Err(e) = client.reconnect(RECONNECT_INTERVAL_SEC).await {
-            reconnect_count -= 1;
-            if reconnect_count == 0 {
-                return Err(e);
-            }
-            Delay::new(Duration::from_secs(RECONNECT_INTERVAL_SEC)).await;
-        }
-    }
-
-    last_err?;
-    unreachable!();
 }
 
 #[cfg(test)]
@@ -176,6 +159,7 @@ mod test {
     fn test_reconnect() {
         struct MockClient {
             reconnect_count: Mutex<usize>,
+            cluster: RwLock<((), Instant)>,
         }
 
         #[async_trait]
@@ -187,82 +171,101 @@ mod test {
                 // Not actually unimplemented, we just don't care about the error.
                 Err(Error::unimplemented())
             }
-
-            async fn with_cluster<T, F: Fn(&Self::Cl) -> T + Send + Sync>(&self, f: F) -> T {
-                f(&())
-            }
         }
 
-        let client = Arc::new(MockClient {
-            reconnect_count: Mutex::new(0),
-        });
+        async fn retry_err(client: Arc<MockClient>) -> Result<()> {
+            retry!(client, |_c| ready(Err(internal_err!("whoops"))))
+        }
 
-        fn ready_err(_: &()) -> impl Future<Output = Result<()>> + Send + 'static {
-            ready(Err(internal_err!("whoops")))
+        async fn retry_ok(client: Arc<MockClient>) -> Result<()> {
+            retry!(client, |_c| ready(Ok::<_, Error>(())))
         }
 
-        let result = executor::block_on(retry_request(client.clone(), ready_err));
-        assert!(result.is_err());
-        assert_eq!(*client.reconnect_count.lock().unwrap(), MAX_REQUEST_COUNT);
+        executor::block_on(async {
+            let client = Arc::new(MockClient {
+                reconnect_count: Mutex::new(0),
+                cluster: RwLock::new(((), Instant::now())),
+            });
+
+            assert!(retry_err(client.clone()).await.is_err());
+            assert_eq!(*client.reconnect_count.lock().unwrap(), MAX_REQUEST_COUNT);
 
-        *client.reconnect_count.lock().unwrap() = 0;
-        let result = executor::block_on(retry_request(client.clone(), |_| ready(Ok(()))));
-        assert!(result.is_ok());
-        assert_eq!(*client.reconnect_count.lock().unwrap(), 0);
+            *client.reconnect_count.lock().unwrap() = 0;
+            assert!(retry_ok(client.clone()).await.is_ok());
+            assert_eq!(*client.reconnect_count.lock().unwrap(), 0);
+        })
     }
 
     #[test]
     fn test_retry() {
         struct MockClient {
-            retry_count: Mutex<usize>,
+            cluster: RwLock<(Mutex<usize>, Instant)>,
         }
 
         #[async_trait]
         impl Reconnect for MockClient {
-            type Cl = ();
+            type Cl = Mutex<usize>;
 
             async fn reconnect(&self, _: u64) -> Result<()> {
                 Ok(())
             }
+        }
 
-            async fn with_cluster<T, F: Fn(&Self::Cl) -> T + Send + Sync>(&self, f: F) -> T {
-                *self.retry_count.lock().unwrap() += 1;
-                f(&())
-            }
+        async fn retry_max_err(
+            client: Arc<MockClient>,
+            max_retries: Arc<Mutex<usize>>,
+        ) -> Result<()> {
+            retry!(client, |c| {
+                let mut c = c.lock().unwrap();
+                *c += 1;
+
+                let mut max_retries = max_retries.lock().unwrap();
+                *max_retries -= 1;
+                if *max_retries == 0 {
+                    ready(Ok(()))
+                } else {
+                    ready(Err(internal_err!("whoops")))
+                }
+            })
         }
 
-        let client = Arc::new(MockClient {
-            retry_count: Mutex::new(0),
-        });
-        let max_retries = Arc::new(Mutex::new(1000));
-
-        let result = executor::block_on(retry_request(client.clone(), |_| {
-            let mut max_retries = max_retries.lock().unwrap();
-            *max_retries -= 1;
-            if *max_retries == 0 {
-                ready(Ok(()))
-            } else {
-                ready(Err(internal_err!("whoops")))
-            }
-        }));
-        assert!(result.is_err());
-        assert_eq!(*client.retry_count.lock().unwrap(), LEADER_CHANGE_RETRY);
-
-        let client = Arc::new(MockClient {
-            retry_count: Mutex::new(0),
-        });
-        let max_retries = Arc::new(Mutex::new(2));
-
-        let result = executor::block_on(retry_request(client.clone(), |_| {
-            let mut max_retries = max_retries.lock().unwrap();
-            *max_retries -= 1;
-            if *max_retries == 0 {
-                ready(Ok(()))
-            } else {
-                ready(Err(internal_err!("whoops")))
-            }
-        }));
-        assert!(result.is_ok());
-        assert_eq!(*client.retry_count.lock().unwrap(), 2);
+        async fn retry_max_ok(
+            client: Arc<MockClient>,
+            max_retries: Arc<Mutex<usize>>,
+        ) -> Result<()> {
+            retry!(client, |c| {
+                let mut c = c.lock().unwrap();
+                *c += 1;
+
+                let mut max_retries = max_retries.lock().unwrap();
+                *max_retries -= 1;
+                if *max_retries == 0 {
+                    ready(Ok(()))
+                } else {
+                    ready(Err(internal_err!("whoops")))
+                }
+            })
+        }
+
+        executor::block_on(async {
+            let client = Arc::new(MockClient {
+                cluster: RwLock::new((Mutex::new(0), Instant::now())),
+            });
+            let max_retries = Arc::new(Mutex::new(1000));
+
+            assert!(retry_max_err(client.clone(), max_retries).await.is_err());
+            assert_eq!(
*client.cluster.read().await.0.lock().unwrap(), + LEADER_CHANGE_RETRY + ); + + let client = Arc::new(MockClient { + cluster: RwLock::new((Mutex::new(0), Instant::now())), + }); + let max_retries = Arc::new(Mutex::new(2)); + + assert!(retry_max_ok(client.clone(), max_retries).await.is_ok()); + assert_eq!(*client.cluster.read().await.0.lock().unwrap(), 2); + }) } } diff --git a/tikv-client-pd/Cargo.toml b/tikv-client-pd/Cargo.toml index b8de8867..879b7071 100644 --- a/tikv-client-pd/Cargo.toml +++ b/tikv-client-pd/Cargo.toml @@ -3,21 +3,20 @@ name = "tikv-client-pd" version = "0.0.0" edition = "2018" - [dependencies] +async-trait = "0.1" derive-new = "0.5" -kvproto = { git = "https://github.com/pingcap/kvproto.git", rev = "1e28226154c374788f38d3a542fc505cd74720f3", features = [ "prost-codec" ], default-features = false } futures = { version = "0.3.5", features = ["compat", "async-await", "thread-pool"] } -tokio = { version = "0.2", features = ["sync"] } grpcio = { version = "0.6", features = [ "secure", "prost-codec" ], default-features = false } +kvproto = { git = "https://github.com/pingcap/kvproto.git", rev = "1e28226154c374788f38d3a542fc505cd74720f3", features = [ "prost-codec" ], default-features = false } log = "0.4" - tikv-client-common = { path = "../tikv-client-common" } +tokio = { version = "0.2", features = ["sync"] } [dev-dependencies] clap = "2.32" -tempdir = "0.3" -tokio = { version = "0.2", features = ["rt-threaded", "macros"] } +fail = { version = "0.3", features = [ "failpoints" ] } proptest = "0.9" proptest-derive = "0.1.0" -fail = { version = "0.3", features = [ "failpoints" ] } +tempdir = "0.3" +tokio = { version = "0.2", features = ["rt-threaded", "macros"] } diff --git a/tikv-client-pd/src/cluster.rs b/tikv-client-pd/src/cluster.rs index f3b63a77..2367d9e3 100644 --- a/tikv-client-pd/src/cluster.rs +++ b/tikv-client-pd/src/cluster.rs @@ -4,7 +4,6 @@ #![allow(dead_code)] use crate::timestamp::TimestampOracle; -use futures::prelude::*; use grpcio::{CallOption, Environment}; use kvproto::{metapb, pdpb}; use std::{ @@ -16,6 +15,14 @@ use tikv_client_common::{ security::SecurityManager, stats::pd_stats, Error, Region, RegionId, Result, StoreId, Timestamp, }; +/// A PD cluster. +pub struct Cluster { + id: u64, + client: pdpb::PdClient, + members: pdpb::GetMembersResponse, + tso: TimestampOracle, +} + macro_rules! pd_request { ($cluster_id:expr, $type:ty) => {{ let mut request = <$type>::default(); @@ -26,126 +33,53 @@ macro_rules! pd_request { }}; } -/// A PD cluster. -pub struct Cluster { - pub id: u64, - pub(super) client: pdpb::PdClient, - members: pdpb::GetMembersResponse, - tso: TimestampOracle, -} - // These methods make a single attempt to make a request. 
 impl Cluster {
-    pub fn get_region(
-        &self,
-        key: Vec<u8>,
-        timeout: Duration,
-    ) -> impl Future<Output = Result<Region>> {
-        let context = pd_stats("get_region");
-        let option = CallOption::default().timeout(timeout);
+    pub async fn get_region(&self, key: Vec<u8>, timeout: Duration) -> Result<Region> {
+        let context = pd_stats(pdpb::GetRegionRequest::TAG);
 
         let mut req = pd_request!(self.id, pdpb::GetRegionRequest);
         req.set_region_key(key.clone());
+        let resp = req.send(&self.client, timeout).await;
 
-        self.client
-            .get_region_async_opt(&req, option)
-            .unwrap()
-            .map(move |r| context.done(r.map_err(|e| e.into())))
-            .and_then(move |resp| {
-                if resp.get_header().has_error() {
-                    return future::ready(Err(internal_err!(resp
-                        .get_header()
-                        .get_error()
-                        .get_message())));
-                }
-                let region = resp
-                    .region
-                    .ok_or_else(|| Error::region_for_key_not_found(key));
-                let leader = resp.leader;
-                future::ready(region.map(move |r| Region::new(r, leader)))
-            })
+        let resp = context.done(resp)?;
+        region_from_response(resp, || Error::region_for_key_not_found(key))
     }
 
-    pub fn get_region_by_id(
-        &self,
-        id: RegionId,
-        timeout: Duration,
-    ) -> impl Future<Output = Result<Region>> {
-        let context = pd_stats("get_region_by_id");
-        let option = CallOption::default().timeout(timeout);
+    pub async fn get_region_by_id(&self, id: RegionId, timeout: Duration) -> Result<Region> {
+        let context = pd_stats(pdpb::GetRegionByIdRequest::TAG);
 
         let mut req = pd_request!(self.id, pdpb::GetRegionByIdRequest);
         req.set_region_id(id);
+        let resp = req.send(&self.client, timeout).await;
 
-        self.client
-            .get_region_by_id_async_opt(&req, option)
-            .unwrap()
-            .map(move |r| context.done(r.map_err(|e| e.into())))
-            .and_then(move |resp| {
-                if resp.get_header().has_error() {
-                    return future::ready(Err(internal_err!(resp
-                        .get_header()
-                        .get_error()
-                        .get_message())));
-                }
-                let region = resp.region.ok_or_else(|| Error::region_not_found(id));
-                let leader = resp.leader;
-                future::ready(region.map(move |r| Region::new(r, leader)))
-            })
+        let resp = context.done(resp)?;
+        region_from_response(resp, || Error::region_not_found(id))
     }
 
-    pub fn get_store(
-        &self,
-        id: StoreId,
-        timeout: Duration,
-    ) -> impl Future<Output = Result<metapb::Store>> {
-        let context = pd_stats("get_store");
-        let option = CallOption::default().timeout(timeout);
+    pub async fn get_store(&self, id: StoreId, timeout: Duration) -> Result<metapb::Store> {
+        let context = pd_stats(pdpb::GetStoreRequest::TAG);
 
         let mut req = pd_request!(self.id, pdpb::GetStoreRequest);
         req.set_store_id(id);
+        let resp = req.send(&self.client, timeout).await;
 
-        self.client
-            .get_store_async_opt(&req, option)
-            .unwrap()
-            .map(move |r| context.done(r.map_err(|e| e.into())))
-            .and_then(|mut resp| {
-                if resp.get_header().has_error() {
-                    return future::ready(Err(internal_err!(resp
-                        .get_header()
-                        .get_error()
-                        .get_message())));
-                }
-                future::ready(Ok(resp.take_store()))
-            })
+        let mut resp = context.done(resp)?;
+        Ok(resp.take_store())
     }
 
-    pub fn get_all_stores(
-        &self,
-        timeout: Duration,
-    ) -> impl Future<Output = Result<Vec<metapb::Store>>> {
-        let context = pd_stats("get_all_stores");
-        let option = CallOption::default().timeout(timeout);
+    pub async fn get_all_stores(&self, timeout: Duration) -> Result<Vec<metapb::Store>> {
+        let context = pd_stats(pdpb::GetAllStoresRequest::TAG);
 
         let req = pd_request!(self.id, pdpb::GetAllStoresRequest);
+        let resp = req.send(&self.client, timeout).await;
 
-        self.client
-            .get_all_stores_async_opt(&req, option)
-            .unwrap()
-            .map(move |r| context.done(r.map_err(|e| e.into())))
-            .and_then(|mut resp| {
-                if resp.get_header().has_error() {
-                    return future::ready(Err(internal_err!(resp
-                        .get_header()
-                        .get_error()
-                        .get_message())));
-                }
-                future::ready(Ok(resp.take_stores().into_iter().map(Into::into).collect()))
-            })
+        let mut resp = context.done(resp)?;
+        Ok(resp.take_stores().into_iter().map(Into::into).collect())
     }
 
-    pub fn get_timestamp(&self) -> impl Future<Output = Result<Timestamp>> {
-        self.tso.clone().get_timestamp()
+    pub async fn get_timestamp(&self) -> Result<Timestamp> {
+        self.tso.clone().get_timestamp().await
     }
 }
 
@@ -214,7 +148,7 @@ impl Connection {
                 Ok(resp) => resp,
                 // Ignore failed PD node.
                 Err(e) => {
-                    error!("PD endpoint {} failed to respond: {:?}", ep, e);
+                    warn!("PD endpoint {} failed to respond: {:?}", ep, e);
                     continue;
                 }
             };
@@ -337,3 +271,110 @@ impl Connection {
         Err(internal_err!("failed to connect to {:?}", members))
     }
 }
+
+type GrpcResult<T> = std::result::Result<T, grpcio::Error>;
+
+#[async_trait]
+trait PdMessage {
+    type Response: PdResponse;
+    const TAG: &'static str;
+
+    async fn rpc(&self, client: &pdpb::PdClient, opt: CallOption) -> GrpcResult<Self::Response>;
+
+    async fn send(&self, client: &pdpb::PdClient, timeout: Duration) -> Result<Self::Response> {
+        let option = CallOption::default().timeout(timeout);
+        let response = self.rpc(client, option).await?;
+
+        if response.header().has_error() {
+            Err(internal_err!(response.header().get_error().get_message()))
+        } else {
+            Ok(response)
+        }
+    }
+}
+
+trait PdResponse {
+    fn header(&self) -> &pdpb::ResponseHeader;
+}
+
+impl PdResponse for pdpb::GetRegionResponse {
+    fn header(&self) -> &pdpb::ResponseHeader {
+        self.get_header()
+    }
+}
+
+impl PdResponse for pdpb::GetStoreResponse {
+    fn header(&self) -> &pdpb::ResponseHeader {
+        self.get_header()
+    }
+}
+
+impl PdResponse for pdpb::GetAllStoresResponse {
+    fn header(&self) -> &pdpb::ResponseHeader {
+        self.get_header()
+    }
+}
+
+#[async_trait]
+impl PdMessage for pdpb::GetRegionRequest {
+    type Response = pdpb::GetRegionResponse;
+    const TAG: &'static str = "get_region";
+
+    async fn rpc(&self, client: &pdpb::PdClient, opt: CallOption) -> GrpcResult<Self::Response> {
+        client
+            .get_region_async_opt(self, opt)
+            .map(Compat01As03::new)
+            .unwrap()
+            .await
+    }
+}
+
+#[async_trait]
+impl PdMessage for pdpb::GetRegionByIdRequest {
+    type Response = pdpb::GetRegionResponse;
+    const TAG: &'static str = "get_region_by_id";
+
+    async fn rpc(&self, client: &pdpb::PdClient, opt: CallOption) -> GrpcResult<Self::Response> {
+        client
+            .get_region_by_id_async_opt(self, opt)
+            .map(Compat01As03::new)
+            .unwrap()
+            .await
+    }
+}
+
+#[async_trait]
+impl PdMessage for pdpb::GetStoreRequest {
+    type Response = pdpb::GetStoreResponse;
+    const TAG: &'static str = "get_store";
+
+    async fn rpc(&self, client: &pdpb::PdClient, opt: CallOption) -> GrpcResult<Self::Response> {
+        client
+            .get_store_async_opt(self, opt)
+            .map(Compat01As03::new)
+            .unwrap()
+            .await
+    }
+}
+
+#[async_trait]
+impl PdMessage for pdpb::GetAllStoresRequest {
+    type Response = pdpb::GetAllStoresResponse;
+    const TAG: &'static str = "get_all_stores";
+
+    async fn rpc(&self, client: &pdpb::PdClient, opt: CallOption) -> GrpcResult<Self::Response> {
+        client
+            .get_all_stores_async_opt(self, opt)
+            .map(Compat01As03::new)
+            .unwrap()
+            .await
+    }
+}
+
+fn region_from_response(
+    resp: pdpb::GetRegionResponse,
+    err: impl FnOnce() -> Error,
+) -> Result<Region> {
+    let region = resp.region.ok_or_else(err)?;
+    Ok(Region::new(region, resp.leader))
+}
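
Note on the new request plumbing (illustrative, not part of the patch): `PdMessage::send` owns the shared per-RPC boilerplate (building the timeout `CallOption`, issuing the call, and turning a `ResponseHeader` error into an `Err`), so supporting a further PD RPC only needs two small impls like the sketch below. It uses `pdpb::GetMembersRequest`/`GetMembersResponse`, which `Connection` already exchanges with PD; the "get_members" tag string and the assumption that the same `Compat01As03` adapter applies to `get_members_async_opt` are mine, not the patch's.

#[async_trait]
impl PdMessage for pdpb::GetMembersRequest {
    // Sketch: same shape as the GetRegion/GetStore impls in the patch.
    type Response = pdpb::GetMembersResponse;
    const TAG: &'static str = "get_members";

    async fn rpc(&self, client: &pdpb::PdClient, opt: CallOption) -> GrpcResult<Self::Response> {
        client
            .get_members_async_opt(self, opt)
            .map(Compat01As03::new)
            .unwrap()
            .await
    }
}

impl PdResponse for pdpb::GetMembersResponse {
    fn header(&self) -> &pdpb::ResponseHeader {
        self.get_header()
    }
}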
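
A second sketch, also illustrative rather than part of the patch: because `Cluster`'s single-attempt methods are now plain `async fn`s returning `Result`, call sites compose them with `.await` and `?` instead of the removed `map`/`and_then` combinator chains. The helper below is hypothetical; only `get_all_stores` and the `metapb::Store` `address` field come from the code above, and retry/reconnect behaviour still belongs to `RetryClient`, not `Cluster`.

// Hypothetical helper: list the advertised address of every store PD knows about.
async fn store_addresses(cluster: &Cluster, timeout: Duration) -> Result<Vec<String>> {
    // Single attempt; on failure the caller (e.g. RetryClient) decides whether to retry.
    let stores = cluster.get_all_stores(timeout).await?;
    Ok(stores.into_iter().map(|store| store.address).collect())
}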