From a8bfdee6105cde19d2a5c4c4b03608319b173925 Mon Sep 17 00:00:00 2001 From: Dmitry Tantsur Date: Sun, 22 Dec 2019 16:47:42 +0100 Subject: [PATCH] feat: add paginated query support --- Cargo.toml | 5 +- examples/list-servers-paginated.rs | 71 ++++++++++++++++++ src/adapter.rs | 104 ++++++++++++++++++++++++++ src/lib.rs | 2 + src/session.rs | 115 +++++++++++++++++++++++++++++ src/stream.rs | 101 +++++++++++++++++++++++++ 6 files changed, 397 insertions(+), 1 deletion(-) create mode 100644 examples/list-servers-paginated.rs create mode 100644 src/stream.rs diff --git a/Cargo.toml b/Cargo.toml index 32e467c..42dba94 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,15 +15,18 @@ edition = "2018" [features] -default = ["native-tls"] +default = ["native-tls", "stream"] native-tls = ["reqwest/default-tls"] rustls = ["reqwest/rustls-tls"] +stream = ["async-stream", "futures"] [dependencies] +async-stream = { version = "^0.2", optional = true } async-trait = "^0.1" chrono = { version = "^0.4", features = ["serde"] } dirs = "^2.0" +futures = { version = "^0.3", optional = true } log = "^0.4" osproto = "^0.2.0" reqwest = { version = "^0.10", default-features = false, features = ["gzip", "json", "stream"] } diff --git a/examples/list-servers-paginated.rs b/examples/list-servers-paginated.rs new file mode 100644 index 0000000..b82f2cd --- /dev/null +++ b/examples/list-servers-paginated.rs @@ -0,0 +1,71 @@ +// Copyright 2019 Dmitry Tantsur +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::env; +use std::str::FromStr; + +use futures::pin_mut; +use futures::stream::TryStreamExt; +use serde::Deserialize; + +#[derive(Debug, Deserialize)] +pub struct Server { + pub id: String, + pub name: String, +} + +#[derive(Debug, Deserialize)] +pub struct ServersRoot { + pub servers: Vec, +} + +impl From for Vec { + fn from(value: ServersRoot) -> Vec { + value.servers + } +} + +impl osauth::stream::Resource for Server { + type Id = String; + type Root = ServersRoot; + fn resource_id(&self) -> Self::Id { + self.id.clone() + } +} + +#[tokio::main] +async fn main() { + env_logger::init(); + let limit = env::args() + .nth(1) + .map(|s| FromStr::from_str(&s).expect("Expected a number")); + + let session = + osauth::from_env().expect("Failed to create an identity provider from the environment"); + let adapter = session.adapter(osauth::services::COMPUTE); + + let sstream = adapter + .get_json_paginated::<_, Server>(&["servers"], None, limit, None) + .await + .expect("Failed to start a GET request"); + pin_mut!(sstream); + while let Some(srv) = sstream + .try_next() + .await + .expect("Failed to fetch the next chunk") + { + println!("ID = {}, Name = {}", srv.id, srv.name); + } + println!("Done listing"); +} diff --git a/src/adapter.rs b/src/adapter.rs index 56e5bfb..3da759f 100644 --- a/src/adapter.rs +++ b/src/adapter.rs @@ -14,6 +14,8 @@ //! Adapter for a specific service. +#[cfg(feature = "stream")] +use futures::Stream; use reqwest::{Method, RequestBuilder, Response, Url}; use serde::de::DeserializeOwned; use serde::Serialize; @@ -21,6 +23,8 @@ use serde::Serialize; use super::config; use super::request; use super::services::ServiceType; +#[cfg(feature = "stream")] +use super::stream::{paginated, Resource}; use super::{ApiVersion, AuthType, Error, Session}; /// Adapter for a specific service. @@ -349,6 +353,78 @@ impl Adapter { request::fetch_json(self.request(Method::GET, path, api_version).await?).await } + /// Fetch a paginated list of JSON objects using the GET request. + /// + /// ```rust,no_run + /// # async fn example() -> Result<(), osauth::Error> { + /// use futures::pin_mut; + /// use futures::stream::TryStreamExt; + /// use serde::Deserialize; + /// + /// #[derive(Debug, Deserialize)] + /// pub struct Server { + /// pub id: String, + /// pub name: String, + /// } + /// + /// #[derive(Debug, Deserialize)] + /// pub struct ServersRoot { + /// pub servers: Vec, + /// } + /// + /// // This implementatin defines the relationship between the root resource and its items. + /// impl osauth::stream::Resource for Server { + /// type Id = String; + /// type Root = ServersRoot; + /// fn resource_id(&self) -> Self::Id { + /// self.id.clone() + /// } + /// } + /// + /// // This is another required part of the pagination contract. + /// impl From for Vec { + /// fn from(value: ServersRoot) -> Vec { + /// value.servers + /// } + /// } + /// + /// let adapter = osauth::from_env() + /// .expect("Failed to create an identity provider from the environment") + /// .into_adapter(osauth::services::COMPUTE); + /// + /// let servers = adapter + /// .get_json_paginated::<_, Server>(&["servers"], None, None, None) + /// .await?; + /// + /// pin_mut!(servers); + /// while let Some(srv) = servers.try_next().await? { + /// println!("ID = {}, Name = {}", srv.id, srv.name); + /// } + /// # Ok(()) } + /// # #[tokio::main] + /// # async fn main() { example().await.unwrap(); } + /// ``` + /// + /// See [request](#method.request) for an explanation of the parameters. + #[cfg(feature = "stream")] + pub async fn get_json_paginated( + &self, + path: I, + api_version: Option, + limit: Option, + starting_with: Option<::Id>, + ) -> Result>, Error> + where + I: IntoIterator, + I::Item: AsRef, + I::IntoIter: Send, + T: Resource + Unpin, + ::Root: Into> + Send, + { + let builder = self.request(Method::GET, path, api_version).await?; + Ok(paginated(builder, limit, starting_with)) + } + /// Fetch a JSON using the GET request with a query. /// /// See `reqwest` crate documentation for how to define a query. @@ -375,6 +451,34 @@ impl Adapter { .await } + /// Fetch a paginated list of JSON objects using the GET request with a query. + /// + /// See `reqwest` crate documentation for how to define a query. + /// See [request](#method.request) for an explanation of the parameters. + #[cfg(feature = "stream")] + pub async fn get_json_query_paginated( + &self, + path: I, + query: Q, + api_version: Option, + limit: Option, + starting_with: Option<::Id>, + ) -> Result>, Error> + where + I: IntoIterator, + I::Item: AsRef, + I::IntoIter: Send, + Q: Serialize + Send, + T: Resource + Unpin, + ::Root: Into> + Send, + { + let builder = self + .request(Method::GET, path, api_version) + .await? + .query(&query); + Ok(paginated(builder, limit, starting_with)) + } + /// Issue a GET request with a query /// /// See `reqwest` crate documentation for how to define a query. diff --git a/src/lib.rs b/src/lib.rs index a3425b8..6d44748 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -102,6 +102,8 @@ mod protocol; pub mod request; pub mod services; mod session; +#[cfg(feature = "stream")] +pub mod stream; mod url; pub use crate::adapter::Adapter; diff --git a/src/session.rs b/src/session.rs index b7a0d9b..560dcfc 100644 --- a/src/session.rs +++ b/src/session.rs @@ -16,6 +16,8 @@ use std::sync::Arc; +#[cfg(feature = "stream")] +use futures::Stream; use log::{debug, trace}; use reqwest::header::HeaderMap; use reqwest::{Method, RequestBuilder, Response, Url}; @@ -26,6 +28,8 @@ use super::cache; use super::protocol::ServiceInfo; use super::request; use super::services::ServiceType; +#[cfg(feature = "stream")] +use super::stream::{paginated, Resource}; use super::url; use super::{Adapter, ApiVersion, AuthType, Error}; @@ -422,6 +426,87 @@ impl Session { .await } + /// Fetch a paginated list of JSON objects using the GET request. + /// + /// ```rust,no_run + /// # async fn example() -> Result<(), osauth::Error> { + /// use futures::pin_mut; + /// use futures::stream::TryStreamExt; + /// use serde::Deserialize; + /// + /// #[derive(Debug, Deserialize)] + /// pub struct Server { + /// pub id: String, + /// pub name: String, + /// } + /// + /// #[derive(Debug, Deserialize)] + /// pub struct ServersRoot { + /// pub servers: Vec, + /// } + /// + /// // This implementatin defines the relationship between the root resource and its items. + /// impl osauth::stream::Resource for Server { + /// type Id = String; + /// type Root = ServersRoot; + /// fn resource_id(&self) -> Self::Id { + /// self.id.clone() + /// } + /// } + /// + /// // This is another required part of the pagination contract. + /// impl From for Vec { + /// fn from(value: ServersRoot) -> Vec { + /// value.servers + /// } + /// } + /// + /// let session = + /// osauth::from_env().expect("Failed to create an identity provider from the environment"); + /// + /// let servers = session + /// .get_json_paginated::<_, _, Server>( + /// osauth::services::COMPUTE, + /// &["servers"], + /// None, + /// None, + /// None + /// ) + /// .await?; + /// + /// pin_mut!(servers); + /// while let Some(srv) = servers.try_next().await? { + /// println!("ID = {}, Name = {}", srv.id, srv.name); + /// } + /// # Ok(()) } + /// # #[tokio::main] + /// # async fn main() { example().await.unwrap(); } + /// ``` + /// + /// See [request](#method.request) for an explanation of the parameters. + #[cfg(feature = "stream")] + pub async fn get_json_paginated( + &self, + service: Srv, + path: I, + api_version: Option, + limit: Option, + starting_with: Option<::Id>, + ) -> Result>, Error> + where + Srv: ServiceType + Send + Clone, + I: IntoIterator, + I::Item: AsRef, + I::IntoIter: Send, + T: Resource + Unpin, + ::Root: Into> + Send, + { + let builder = self + .request(service, Method::GET, path, api_version) + .await?; + Ok(paginated(builder, limit, starting_with)) + } + /// Fetch a JSON using the GET request with a query. /// /// See `reqwest` crate documentation for how to define a query. @@ -450,6 +535,36 @@ impl Session { .await } + /// Fetch a paginated list of JSON objects using the GET request with a query. + /// + /// See `reqwest` crate documentation for how to define a query. + /// See [request](#method.request) for an explanation of the parameters. + #[cfg(feature = "stream")] + pub async fn get_json_query_paginated( + &self, + service: Srv, + path: I, + query: Q, + api_version: Option, + limit: Option, + starting_with: Option<::Id>, + ) -> Result>, Error> + where + Srv: ServiceType + Send + Clone, + I: IntoIterator, + I::Item: AsRef, + I::IntoIter: Send, + Q: Serialize + Send, + T: Resource + Unpin, + ::Root: Into> + Send, + { + let builder = self + .request(service, Method::GET, path, api_version) + .await? + .query(&query); + Ok(paginated(builder, limit, starting_with)) + } + /// Issue a GET request with a query /// /// See `reqwest` crate documentation for how to define a query. diff --git a/src/stream.rs b/src/stream.rs new file mode 100644 index 0000000..dfd530a --- /dev/null +++ b/src/stream.rs @@ -0,0 +1,101 @@ +// Copyright 2019 Dmitry Tantsur +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! A stream of resources. + +use std::fmt::Debug; + +use async_stream::try_stream; +use futures::pin_mut; +use futures::stream::{Stream, TryStreamExt}; +use reqwest::RequestBuilder; +use serde::de::DeserializeOwned; +use serde::Serialize; + +use super::request; +use super::Error; + +/// A single resource. +pub trait Resource { + /// Type of an ID. + type Id: Debug + Serialize; + + /// Root type of the listing. + type Root: DeserializeOwned; + + /// Retrieve a copy of the ID. + fn resource_id(&self) -> Self::Id; +} + +#[derive(Serialize)] +struct Query { + #[serde(skip_serializing_if = "Option::is_none")] + limit: Option, + #[serde(skip_serializing_if = "Option::is_none")] + marker: Option, +} + +fn chunks( + builder: RequestBuilder, + limit: Option, + starting_with: Option, +) -> impl Stream, Error>> +where + T: Resource + Unpin, + T::Root: Into> + Send, +{ + let mut marker = starting_with; + + try_stream! { + loop { + let prepared = builder + .try_clone() + .expect("Builder with a streaming body cannot be used") + .query(&Query{ limit: limit, marker: marker.take() }); + let result: T::Root = request::fetch_json(prepared).await?; + let items = result.into(); + if let Some(new_m) = items.last() { + marker = Some(new_m.resource_id()); + yield items; + } else { + break + } + } + } +} + +/// Creates a paginated resource stream. +/// +/// # Panics +/// +/// Will panic during iteration if the request builder has a streaming body. +pub fn paginated( + builder: RequestBuilder, + limit: Option, + starting_with: Option, +) -> impl Stream> +where + T: Resource + Unpin, + T::Root: Into> + Send, +{ + try_stream! { + let iter = chunks(builder, limit, starting_with); + pin_mut!(iter); + while let Some(chunk) = iter.try_next().await? { + for item in chunk { + yield item; + } + } + } +}