
feat(iroh): add more rpc methods #1962

Merged · 5 commits · Jan 23, 2024

Changes from all commits
iroh/src/client.rs — 265 changes: 242 additions & 23 deletions
@@ -35,18 +35,19 @@ use tracing::warn;

use crate::rpc_protocol::{
AuthorCreateRequest, AuthorListRequest, BlobAddPathRequest, BlobAddStreamRequest,
BlobAddStreamUpdate, BlobDeleteBlobRequest, BlobDownloadRequest, BlobListCollectionsRequest,
BlobListCollectionsResponse, BlobListIncompleteRequest, BlobListIncompleteResponse,
BlobListRequest, BlobListResponse, BlobReadRequest, BlobReadResponse, BlobValidateRequest,
CounterStats, CreateCollectionRequest, CreateCollectionResponse, DeleteTagRequest,
DocCloseRequest, DocCreateRequest, DocDelRequest, DocDelResponse, DocDropRequest,
DocExportFileRequest, DocExportProgress, DocGetDownloadPolicyRequest, DocGetExactRequest,
DocGetManyRequest, DocImportFileRequest, DocImportProgress, DocImportRequest, DocLeaveRequest,
DocListRequest, DocOpenRequest, DocSetDownloadPolicyRequest, DocSetHashRequest, DocSetRequest,
DocShareRequest, DocStartSyncRequest, DocStatusRequest, DocSubscribeRequest, DocTicket,
DownloadProgress, ListTagsRequest, ListTagsResponse, NodeConnectionInfoRequest,
NodeConnectionInfoResponse, NodeConnectionsRequest, NodeShutdownRequest, NodeStatsRequest,
NodeStatusRequest, NodeStatusResponse, ProviderService, SetTagOption, ShareMode, WrapOption,
BlobAddStreamUpdate, BlobDeleteBlobRequest, BlobDownloadRequest, BlobGetCollectionRequest,
BlobGetCollectionResponse, BlobListCollectionsRequest, BlobListCollectionsResponse,
BlobListIncompleteRequest, BlobListIncompleteResponse, BlobListRequest, BlobListResponse,
BlobReadAtRequest, BlobReadAtResponse, BlobValidateRequest, CounterStats,
CreateCollectionRequest, CreateCollectionResponse, DeleteTagRequest, DocCloseRequest,
DocCreateRequest, DocDelRequest, DocDelResponse, DocDropRequest, DocExportFileRequest,
DocExportProgress, DocGetDownloadPolicyRequest, DocGetExactRequest, DocGetManyRequest,
DocImportFileRequest, DocImportProgress, DocImportRequest, DocLeaveRequest, DocListRequest,
DocOpenRequest, DocSetDownloadPolicyRequest, DocSetHashRequest, DocSetRequest, DocShareRequest,
DocStartSyncRequest, DocStatusRequest, DocSubscribeRequest, DocTicket, DownloadProgress,
ListTagsRequest, ListTagsResponse, NodeConnectionInfoRequest, NodeConnectionInfoResponse,
NodeConnectionsRequest, NodeShutdownRequest, NodeStatsRequest, NodeStatusRequest,
NodeStatusResponse, ProviderService, SetTagOption, ShareMode, WrapOption,
};
use crate::sync_engine::SyncEvent;

@@ -240,7 +241,14 @@ where
///
/// Returns a [`BlobReader`], which can report the size of the blob before reading it.
pub async fn read(&self, hash: Hash) -> Result<BlobReader> {
BlobReader::from_rpc(&self.rpc, hash).await
BlobReader::from_rpc_read(&self.rpc, hash).await
}

/// Read `len` bytes starting at `offset` from a single blob.
///
/// If `len` is `None`, it reads from `offset` to the end of the blob.
pub async fn read_at(&self, hash: Hash, offset: u64, len: Option<usize>) -> Result<BlobReader> {
BlobReader::from_rpc_read_at(&self.rpc, hash, offset, len).await
}
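
// A minimal usage sketch, assuming a running node, its `client` handle, and a
// blob `hash` imported beforehand (as in the tests below); the offset 1024 and
// length 512 are arbitrary illustrative values:
//
//     let mut reader = client.blobs.read_at(hash, 1024, Some(512)).await?;
//     println!("total blob size: {}", reader.size());
//     let bytes = reader.read_to_bytes().await?; // at most 512 bytes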

/// Read all bytes of a single blob.
@@ -249,7 +257,22 @@ where
/// reading is small. If not sure, use [`Self::read`] and check the size with
/// [`BlobReader::size`] before calling [`BlobReader::read_to_bytes`].
pub async fn read_to_bytes(&self, hash: Hash) -> Result<Bytes> {
BlobReader::from_rpc(&self.rpc, hash)
BlobReader::from_rpc_read(&self.rpc, hash)
.await?
.read_to_bytes()
.await
}

/// Read all bytes of a single blob at `offset` for length `len`.
///
/// This allocates a buffer for the full length.
pub async fn read_at_to_bytes(
&self,
hash: Hash,
offset: u64,
len: Option<usize>,
) -> Result<Bytes> {
BlobReader::from_rpc_read_at(&self.rpc, hash, offset, len)
.await?
.read_to_bytes()
.await
@@ -387,6 +410,13 @@ where
Ok(stream.map_err(anyhow::Error::from))
}

/// Read the content of a collection.
pub async fn get_collection(&self, hash: Hash) -> Result<Collection> {
let BlobGetCollectionResponse { collection } =
self.rpc.rpc(BlobGetCollectionRequest { hash }).await??;
Ok(collection)
}
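
// A minimal usage sketch, assuming `client` and a collection `hash` such as the
// one produced by `create_collection` in the tests below:
//
//     let collection = client.blobs.get_collection(hash).await?;
//     println!("collection holds {} blobs", collection.len());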

/// List all collections.
pub async fn list_collections(
&self,
@@ -485,38 +515,58 @@ impl Stream for BlobAddProgress {
#[derive(derive_more::Debug)]
pub struct BlobReader {
size: u64,
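/// Size of the data this request will return; for ranged reads this can be
/// smaller than the total `size`.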
response_size: u64,
is_complete: bool,
#[debug("StreamReader")]
stream: tokio_util::io::StreamReader<BoxStream<'static, io::Result<Bytes>>, Bytes>,
}

impl BlobReader {
fn new(size: u64, is_complete: bool, stream: BoxStream<'static, io::Result<Bytes>>) -> Self {
fn new(
size: u64,
response_size: u64,
is_complete: bool,
stream: BoxStream<'static, io::Result<Bytes>>,
) -> Self {
Self {
size,
response_size,
is_complete,
stream: StreamReader::new(stream),
}
}

async fn from_rpc<C: ServiceConnection<ProviderService>>(
async fn from_rpc_read<C: ServiceConnection<ProviderService>>(
rpc: &RpcClient<ProviderService, C>,
hash: Hash,
) -> anyhow::Result<Self> {
let stream = rpc.server_streaming(BlobReadRequest { hash }).await?;
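// Reading the whole blob is a ranged read from offset 0 with no length limit.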
Self::from_rpc_read_at(rpc, hash, 0, None).await
}

async fn from_rpc_read_at<C: ServiceConnection<ProviderService>>(
rpc: &RpcClient<ProviderService, C>,
hash: Hash,
offset: u64,
len: Option<usize>,
) -> anyhow::Result<Self> {
let stream = rpc
.server_streaming(BlobReadAtRequest { hash, offset, len })
.await?;
let mut stream = flatten(stream);

let (size, is_complete) = match stream.next().await {
Some(Ok(BlobReadResponse::Entry { size, is_complete })) => (size, is_complete),
Some(Ok(BlobReadAtResponse::Entry { size, is_complete })) => (size, is_complete),
Some(Err(err)) => return Err(err),
None | Some(Ok(_)) => return Err(anyhow!("Expected header frame")),
};

let stream = stream.map(|item| match item {
Ok(BlobReadResponse::Data { chunk }) => Ok(chunk),
Ok(BlobReadAtResponse::Data { chunk }) => Ok(chunk),
Ok(_) => Err(io::Error::new(io::ErrorKind::Other, "Expected data frame")),
Err(err) => Err(io::Error::new(io::ErrorKind::Other, format!("{err}"))),
});
Ok(Self::new(size, is_complete, stream.boxed()))
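// When `len` is `None`, the response covers everything from `offset` to the end of the blob.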
let len = len.map(|l| l as u64).unwrap_or_else(|| size - offset);
Ok(Self::new(size, len, is_complete, stream.boxed()))
}

/// Total size of this blob.
@@ -533,7 +583,7 @@ impl BlobReader {

/// Read all bytes of the blob.
pub async fn read_to_bytes(&mut self) -> anyhow::Result<Bytes> {
let mut buf = Vec::with_capacity(self.size() as usize);
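// Pre-allocate only as much as this response will return, which for ranged
// reads can be smaller than the total blob size.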
let mut buf = Vec::with_capacity(self.response_size as usize);
self.read_to_end(&mut buf).await?;
Ok(buf.into())
}
@@ -908,7 +958,7 @@ impl Entry {
where
C: ServiceConnection<ProviderService>,
{
BlobReader::from_rpc(client.into(), self.content_hash()).await
BlobReader::from_rpc_read(client.into(), self.content_hash()).await
}

/// Read all content of an [`Entry`] into a buffer.
@@ -921,7 +971,7 @@ impl Entry {
where
C: ServiceConnection<ProviderService>,
{
BlobReader::from_rpc(client.into(), self.content_hash())
BlobReader::from_rpc_read(client.into(), self.content_hash())
.await?
.read_to_bytes()
.await
@@ -1321,4 +1371,173 @@ mod tests {

Ok(())
}

#[tokio::test]
async fn test_blob_read_at() -> Result<()> {
// let _guard = iroh_test::logging::setup();

let doc_store = iroh_sync::store::memory::Store::default();
let db = iroh_bytes::store::mem::Store::new();
let node = crate::node::Node::builder(db, doc_store).spawn().await?;

// create temp file
let temp_dir = tempfile::tempdir().context("tempdir")?;

let in_root = temp_dir.path().join("in");
tokio::fs::create_dir_all(in_root.clone())
.await
.context("create dir all")?;

let path = in_root.join("test-blob");
let size = 1024 * 128;
let buf: Vec<u8> = (0..size).map(|i| i as u8).collect();
let mut file = tokio::fs::File::create(path.clone())
.await
.context("create file")?;
file.write_all(&buf.clone()).await.context("write_all")?;
file.flush().await.context("flush")?;

let client = node.client();

let import_outcome = client
.blobs
.add_from_path(
path.to_path_buf(),
false,
SetTagOption::Auto,
WrapOption::NoWrap,
)
.await
.context("import file")?
.finish()
.await
.context("import finish")?;

let hash = import_outcome.hash;

// Read everything
let res = client.blobs.read_to_bytes(hash).await?;
assert_eq!(&res, &buf[..]);

// Read at smaller than blob_get_chunk_size
let res = client.blobs.read_at_to_bytes(hash, 0, Some(100)).await?;
assert_eq!(res.len(), 100);
assert_eq!(&res[..], &buf[0..100]);

let res = client.blobs.read_at_to_bytes(hash, 20, Some(120)).await?;
assert_eq!(res.len(), 120);
assert_eq!(&res[..], &buf[20..140]);

// Read at equal to blob_get_chunk_size
let res = client
.blobs
.read_at_to_bytes(hash, 0, Some(1024 * 64))
.await?;
assert_eq!(res.len(), 1024 * 64);
assert_eq!(&res[..], &buf[0..1024 * 64]);

let res = client
.blobs
.read_at_to_bytes(hash, 20, Some(1024 * 64))
.await?;
assert_eq!(res.len(), 1024 * 64);
assert_eq!(&res[..], &buf[20..(20 + 1024 * 64)]);

// Read at larger than blob_get_chunk_size
let res = client
.blobs
.read_at_to_bytes(hash, 0, Some(10 + 1024 * 64))
.await?;
assert_eq!(res.len(), 10 + 1024 * 64);
assert_eq!(&res[..], &buf[0..(10 + 1024 * 64)]);

let res = client
.blobs
.read_at_to_bytes(hash, 20, Some(10 + 1024 * 64))
.await?;
assert_eq!(res.len(), 10 + 1024 * 64);
assert_eq!(&res[..], &buf[20..(20 + 10 + 1024 * 64)]);

// full length
let res = client.blobs.read_at_to_bytes(hash, 20, None).await?;
assert_eq!(res.len(), 1024 * 128 - 20);
assert_eq!(&res[..], &buf[20..]);

// size should be total
let reader = client.blobs.read_at(hash, 0, Some(20)).await?;
assert_eq!(reader.size(), 1024 * 128);
assert_eq!(reader.response_size, 20);

Ok(())
}

#[tokio::test]
async fn test_blob_get_collection() -> Result<()> {
let _guard = iroh_test::logging::setup();

let doc_store = iroh_sync::store::memory::Store::default();
let db = iroh_bytes::store::mem::Store::new();
let node = crate::node::Node::builder(db, doc_store).spawn().await?;

// create temp file
let temp_dir = tempfile::tempdir().context("tempdir")?;

let in_root = temp_dir.path().join("in");
tokio::fs::create_dir_all(in_root.clone())
.await
.context("create dir all")?;

let mut paths = Vec::new();
for i in 0..5 {
let path = in_root.join(format!("test-{i}"));
let size = 100;
let mut buf = vec![0u8; size];
rand::thread_rng().fill_bytes(&mut buf);
let mut file = tokio::fs::File::create(path.clone())
.await
.context("create file")?;
file.write_all(&buf.clone()).await.context("write_all")?;
file.flush().await.context("flush")?;
paths.push(path);
}

let client = node.client();

let mut collection = Collection::default();
let mut tags = Vec::new();
// import files
for path in &paths {
let import_outcome = client
.blobs
.add_from_path(
path.to_path_buf(),
false,
SetTagOption::Auto,
WrapOption::NoWrap,
)
.await
.context("import file")?
.finish()
.await
.context("import finish")?;

collection.push(
path.file_name().unwrap().to_str().unwrap().to_string(),
import_outcome.hash,
);
tags.push(import_outcome.tag);
}

let (hash, _tag) = client
.blobs
.create_collection(collection, SetTagOption::Auto, tags)
.await?;

let collection = client.blobs.get_collection(hash).await?;

// 5 blobs
assert_eq!(collection.len(), 5);

Ok(())
}
}