diff --git a/iroh/src/client.rs b/iroh/src/client.rs
index de22cdd516..c08351fcc3 100644
--- a/iroh/src/client.rs
+++ b/iroh/src/client.rs
@@ -35,18 +35,19 @@ use tracing::warn;
 
 use crate::rpc_protocol::{
     AuthorCreateRequest, AuthorListRequest, BlobAddPathRequest, BlobAddStreamRequest,
-    BlobAddStreamUpdate, BlobDeleteBlobRequest, BlobDownloadRequest, BlobListCollectionsRequest,
-    BlobListCollectionsResponse, BlobListIncompleteRequest, BlobListIncompleteResponse,
-    BlobListRequest, BlobListResponse, BlobReadRequest, BlobReadResponse, BlobValidateRequest,
-    CounterStats, CreateCollectionRequest, CreateCollectionResponse, DeleteTagRequest,
-    DocCloseRequest, DocCreateRequest, DocDelRequest, DocDelResponse, DocDropRequest,
-    DocExportFileRequest, DocExportProgress, DocGetDownloadPolicyRequest, DocGetExactRequest,
-    DocGetManyRequest, DocImportFileRequest, DocImportProgress, DocImportRequest, DocLeaveRequest,
-    DocListRequest, DocOpenRequest, DocSetDownloadPolicyRequest, DocSetHashRequest, DocSetRequest,
-    DocShareRequest, DocStartSyncRequest, DocStatusRequest, DocSubscribeRequest, DocTicket,
-    DownloadProgress, ListTagsRequest, ListTagsResponse, NodeConnectionInfoRequest,
-    NodeConnectionInfoResponse, NodeConnectionsRequest, NodeShutdownRequest, NodeStatsRequest,
-    NodeStatusRequest, NodeStatusResponse, ProviderService, SetTagOption, ShareMode, WrapOption,
+    BlobAddStreamUpdate, BlobDeleteBlobRequest, BlobDownloadRequest, BlobGetCollectionRequest,
+    BlobGetCollectionResponse, BlobListCollectionsRequest, BlobListCollectionsResponse,
+    BlobListIncompleteRequest, BlobListIncompleteResponse, BlobListRequest, BlobListResponse,
+    BlobReadAtRequest, BlobReadAtResponse, BlobValidateRequest, CounterStats,
+    CreateCollectionRequest, CreateCollectionResponse, DeleteTagRequest, DocCloseRequest,
+    DocCreateRequest, DocDelRequest, DocDelResponse, DocDropRequest, DocExportFileRequest,
+    DocExportProgress, DocGetDownloadPolicyRequest, DocGetExactRequest, DocGetManyRequest,
+    DocImportFileRequest, DocImportProgress, DocImportRequest, DocLeaveRequest, DocListRequest,
+    DocOpenRequest, DocSetDownloadPolicyRequest, DocSetHashRequest, DocSetRequest, DocShareRequest,
+    DocStartSyncRequest, DocStatusRequest, DocSubscribeRequest, DocTicket, DownloadProgress,
+    ListTagsRequest, ListTagsResponse, NodeConnectionInfoRequest, NodeConnectionInfoResponse,
+    NodeConnectionsRequest, NodeShutdownRequest, NodeStatsRequest, NodeStatusRequest,
+    NodeStatusResponse, ProviderService, SetTagOption, ShareMode, WrapOption,
 };
 use crate::sync_engine::SyncEvent;
@@ -240,7 +241,14 @@ where
     ///
     /// Returns a [`BlobReader`], which can report the size of the blob before reading it.
     pub async fn read(&self, hash: Hash) -> Result<BlobReader> {
-        BlobReader::from_rpc(&self.rpc, hash).await
+        BlobReader::from_rpc_read(&self.rpc, hash).await
+    }
+
+    /// Read offset + len from a single blob.
+    ///
+    /// If `len` is `None` it will read the full blob.
+    pub async fn read_at(&self, hash: Hash, offset: u64, len: Option<usize>) -> Result<BlobReader> {
+        BlobReader::from_rpc_read_at(&self.rpc, hash, offset, len).await
     }
 
     /// Read all bytes of a single blob.
@@ -249,7 +257,22 @@
     ///
     /// This allocates a buffer for the full blob. Use only if you know that the blob you're
     /// reading is small. If not sure, use [`Self::read`] and check the size with
     /// [`BlobReader::size`] before calling [`BlobReader::read_to_bytes`].
     pub async fn read_to_bytes(&self, hash: Hash) -> Result<Bytes> {
-        BlobReader::from_rpc(&self.rpc, hash)
+        BlobReader::from_rpc_read(&self.rpc, hash)
+            .await?
+            .read_to_bytes()
+            .await
+    }
+
+    /// Read all bytes of a single blob at `offset` for length `len`.
+    ///
+    /// This allocates a buffer for the full length.
+    pub async fn read_at_to_bytes(
+        &self,
+        hash: Hash,
+        offset: u64,
+        len: Option<usize>,
+    ) -> Result<Bytes> {
+        BlobReader::from_rpc_read_at(&self.rpc, hash, offset, len)
             .await?
             .read_to_bytes()
             .await
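Review note: a minimal usage sketch of the new read surface, not code from this diff. It assumes the same in-memory node setup as the tests further down; `import_blob` is a hypothetical helper standing in for the `add_from_path(..).finish()` sequence those tests use.

```rust
// Sketch only (not part of the diff). `import_blob` is a hypothetical helper
// that imports a 1024-byte blob and returns its hash.
async fn read_at_sketch() -> anyhow::Result<()> {
    let doc_store = iroh_sync::store::memory::Store::default();
    let db = iroh_bytes::store::mem::Store::new();
    let node = crate::node::Node::builder(db, doc_store).spawn().await?;
    let client = node.client();

    let hash = import_blob(&client).await?; // hypothetical, 1024-byte blob

    // Fetch only bytes 16..80; the rest of the blob is never transferred.
    let slice = client.blobs.read_at_to_bytes(hash, 16, Some(64)).await?;
    assert_eq!(slice.len(), 64);

    // A ranged reader still reports the *total* blob size via size(),
    // even though only the requested 64 bytes are streamed.
    let reader = client.blobs.read_at(hash, 16, Some(64)).await?;
    assert_eq!(reader.size(), 1024);

    Ok(())
}
```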
@@ -387,6 +410,13 @@ where
         Ok(stream.map_err(anyhow::Error::from))
     }
 
+    /// Read the content of a collection.
+    pub async fn get_collection(&self, hash: Hash) -> Result<Collection> {
+        let BlobGetCollectionResponse { collection } =
+            self.rpc.rpc(BlobGetCollectionRequest { hash }).await??;
+        Ok(collection)
+    }
+
     /// List all collections.
     pub async fn list_collections(
         &self,
@@ -485,38 +515,58 @@ impl Stream for BlobAddProgress {
 #[derive(derive_more::Debug)]
 pub struct BlobReader {
     size: u64,
+    response_size: u64,
     is_complete: bool,
     #[debug("StreamReader")]
     stream: tokio_util::io::StreamReader<BoxStream<'static, io::Result<Bytes>>, Bytes>,
 }
+
 impl BlobReader {
-    fn new(size: u64, is_complete: bool, stream: BoxStream<'static, io::Result<Bytes>>) -> Self {
+    fn new(
+        size: u64,
+        response_size: u64,
+        is_complete: bool,
+        stream: BoxStream<'static, io::Result<Bytes>>,
+    ) -> Self {
         Self {
             size,
+            response_size,
             is_complete,
             stream: StreamReader::new(stream),
        }
    }
 
-    async fn from_rpc<C: ServiceConnection<ProviderService>>(
+    async fn from_rpc_read<C: ServiceConnection<ProviderService>>(
         rpc: &RpcClient<ProviderService, C>,
         hash: Hash,
     ) -> anyhow::Result<Self> {
-        let stream = rpc.server_streaming(BlobReadRequest { hash }).await?;
+        Self::from_rpc_read_at(rpc, hash, 0, None).await
+    }
+
+    async fn from_rpc_read_at<C: ServiceConnection<ProviderService>>(
+        rpc: &RpcClient<ProviderService, C>,
+        hash: Hash,
+        offset: u64,
+        len: Option<usize>,
+    ) -> anyhow::Result<Self> {
+        let stream = rpc
+            .server_streaming(BlobReadAtRequest { hash, offset, len })
+            .await?;
         let mut stream = flatten(stream);
 
         let (size, is_complete) = match stream.next().await {
-            Some(Ok(BlobReadResponse::Entry { size, is_complete })) => (size, is_complete),
+            Some(Ok(BlobReadAtResponse::Entry { size, is_complete })) => (size, is_complete),
             Some(Err(err)) => return Err(err),
             None | Some(Ok(_)) => return Err(anyhow!("Expected header frame")),
         };
 
         let stream = stream.map(|item| match item {
-            Ok(BlobReadResponse::Data { chunk }) => Ok(chunk),
+            Ok(BlobReadAtResponse::Data { chunk }) => Ok(chunk),
             Ok(_) => Err(io::Error::new(io::ErrorKind::Other, "Expected data frame")),
             Err(err) => Err(io::Error::new(io::ErrorKind::Other, format!("{err}"))),
         });
-        Ok(Self::new(size, is_complete, stream.boxed()))
+        let len = len.map(|l| l as u64).unwrap_or_else(|| size - offset);
+        Ok(Self::new(size, len, is_complete, stream.boxed()))
     }
 
     /// Total size of this blob.
@@ -533,7 +583,7 @@ impl BlobReader {
 
     /// Read all bytes of the blob.
     pub async fn read_to_bytes(&mut self) -> anyhow::Result<Bytes> {
-        let mut buf = Vec::with_capacity(self.size() as usize);
+        let mut buf = Vec::with_capacity(self.response_size as usize);
         self.read_to_end(&mut buf).await?;
         Ok(buf.into())
     }
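Review note: `read_to_bytes` now pre-allocates from `response_size` (the requested range) rather than the total blob size, so a small ranged read of a large blob no longer over-allocates. Below is the guarded-read pattern the doc comment recommends, written against plain `AsyncRead` so the sketch stands alone; the threshold and function name are illustrative, not from this diff.

```rust
use bytes::Bytes;
use tokio::io::{AsyncRead, AsyncReadExt};

/// Illustrative threshold, not from this diff.
const MAX_IN_MEMORY: u64 = 1024 * 1024;

/// Sketch of the pattern the `read_to_bytes` docs suggest. With a real
/// `BlobReader`, `size` would come from `reader.size()`.
async fn buffer_if_small<R: AsyncRead + Unpin>(
    mut reader: R,
    size: u64,
) -> std::io::Result<Option<Bytes>> {
    if size > MAX_IN_MEMORY {
        // Too large to buffer; the caller should stream incrementally instead.
        return Ok(None);
    }
    // Mirrors read_to_bytes above: pre-allocate exactly the expected length.
    let mut buf = Vec::with_capacity(size as usize);
    reader.read_to_end(&mut buf).await?;
    Ok(Some(buf.into()))
}

#[tokio::main]
async fn main() -> std::io::Result<()> {
    let data = vec![7u8; 4096];
    let buffered = buffer_if_small(&data[..], data.len() as u64).await?;
    assert_eq!(buffered.map(|b| b.len()), Some(4096));
    Ok(())
}
```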
@@ -908,7 +958,7 @@ impl Entry {
     where
         C: ServiceConnection<ProviderService>,
     {
-        BlobReader::from_rpc(client.into(), self.content_hash()).await
+        BlobReader::from_rpc_read(client.into(), self.content_hash()).await
     }
 
     /// Read all content of an [`Entry`] into a buffer.
@@ -921,7 +971,7 @@ impl Entry {
     where
         C: ServiceConnection<ProviderService>,
     {
-        BlobReader::from_rpc(client.into(), self.content_hash())
+        BlobReader::from_rpc_read(client.into(), self.content_hash())
             .await?
             .read_to_bytes()
             .await
@@ -1321,4 +1371,173 @@ mod tests {
 
         Ok(())
     }
+
+    #[tokio::test]
+    async fn test_blob_read_at() -> Result<()> {
+        // let _guard = iroh_test::logging::setup();
+
+        let doc_store = iroh_sync::store::memory::Store::default();
+        let db = iroh_bytes::store::mem::Store::new();
+        let node = crate::node::Node::builder(db, doc_store).spawn().await?;
+
+        // create temp file
+        let temp_dir = tempfile::tempdir().context("tempdir")?;
+
+        let in_root = temp_dir.path().join("in");
+        tokio::fs::create_dir_all(in_root.clone())
+            .await
+            .context("create dir all")?;
+
+        let path = in_root.join("test-blob");
+        let size = 1024 * 128;
+        let buf: Vec<u8> = (0..size).map(|i| i as u8).collect();
+        let mut file = tokio::fs::File::create(path.clone())
+            .await
+            .context("create file")?;
+        file.write_all(&buf.clone()).await.context("write_all")?;
+        file.flush().await.context("flush")?;
+
+        let client = node.client();
+
+        let import_outcome = client
+            .blobs
+            .add_from_path(
+                path.to_path_buf(),
+                false,
+                SetTagOption::Auto,
+                WrapOption::NoWrap,
+            )
+            .await
+            .context("import file")?
+            .finish()
+            .await
+            .context("import finish")?;
+
+        let hash = import_outcome.hash;
+
+        // Read everything
+        let res = client.blobs.read_to_bytes(hash).await?;
+        assert_eq!(&res, &buf[..]);
+
+        // Read at smaller than blob_get_chunk_size
+        let res = client.blobs.read_at_to_bytes(hash, 0, Some(100)).await?;
+        assert_eq!(res.len(), 100);
+        assert_eq!(&res[..], &buf[0..100]);
+
+        let res = client.blobs.read_at_to_bytes(hash, 20, Some(120)).await?;
+        assert_eq!(res.len(), 120);
+        assert_eq!(&res[..], &buf[20..140]);
+
+        // Read at equal to blob_get_chunk_size
+        let res = client
+            .blobs
+            .read_at_to_bytes(hash, 0, Some(1024 * 64))
+            .await?;
+        assert_eq!(res.len(), 1024 * 64);
+        assert_eq!(&res[..], &buf[0..1024 * 64]);
+
+        let res = client
+            .blobs
+            .read_at_to_bytes(hash, 20, Some(1024 * 64))
+            .await?;
+        assert_eq!(res.len(), 1024 * 64);
+        assert_eq!(&res[..], &buf[20..(20 + 1024 * 64)]);
+
+        // Read at larger than blob_get_chunk_size
+        let res = client
+            .blobs
+            .read_at_to_bytes(hash, 0, Some(10 + 1024 * 64))
+            .await?;
+        assert_eq!(res.len(), 10 + 1024 * 64);
+        assert_eq!(&res[..], &buf[0..(10 + 1024 * 64)]);
+
+        let res = client
+            .blobs
+            .read_at_to_bytes(hash, 20, Some(10 + 1024 * 64))
+            .await?;
+        assert_eq!(res.len(), 10 + 1024 * 64);
+        assert_eq!(&res[..], &buf[20..(20 + 10 + 1024 * 64)]);
+
+        // full length
+        let res = client.blobs.read_at_to_bytes(hash, 20, None).await?;
+        assert_eq!(res.len(), 1024 * 128 - 20);
+        assert_eq!(&res[..], &buf[20..]);
+
+        // size should be total
+        let reader = client.blobs.read_at(hash, 0, Some(20)).await?;
+        assert_eq!(reader.size(), 1024 * 128);
+        assert_eq!(reader.response_size, 20);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_blob_get_collection() -> Result<()> {
+        let _guard = iroh_test::logging::setup();
+
+        let doc_store = iroh_sync::store::memory::Store::default();
+        let db = iroh_bytes::store::mem::Store::new();
+        let node = crate::node::Node::builder(db, doc_store).spawn().await?;
+
+        // create temp file
+        let temp_dir = tempfile::tempdir().context("tempdir")?;
+
+        let in_root = temp_dir.path().join("in");
+        tokio::fs::create_dir_all(in_root.clone())
+            .await
+            .context("create dir all")?;
+
+        let mut paths = Vec::new();
+        for i in 0..5 {
+            let path = in_root.join(format!("test-{i}"));
+            let size = 100;
+            let mut buf = vec![0u8; size];
+            rand::thread_rng().fill_bytes(&mut buf);
+            let mut file = tokio::fs::File::create(path.clone())
+                .await
+                .context("create file")?;
+            file.write_all(&buf.clone()).await.context("write_all")?;
+            file.flush().await.context("flush")?;
+            paths.push(path);
+        }
+
+        let client = node.client();
+
+        let mut collection = Collection::default();
+        let mut tags = Vec::new();
+        // import files
+        for path in &paths {
+            let import_outcome = client
+                .blobs
+                .add_from_path(
+                    path.to_path_buf(),
+                    false,
+                    SetTagOption::Auto,
+                    WrapOption::NoWrap,
+                )
+                .await
+                .context("import file")?
+                .finish()
+                .await
+                .context("import finish")?;
+
+            collection.push(
+                path.file_name().unwrap().to_str().unwrap().to_string(),
+                import_outcome.hash,
+            );
+            tags.push(import_outcome.tag);
+        }
+
+        let (hash, _tag) = client
+            .blobs
+            .create_collection(collection, SetTagOption::Auto, tags)
+            .await?;
+
+        let collection = client.blobs.get_collection(hash).await?;
+
+        // 5 blobs
+        assert_eq!(collection.len(), 5);
+
+        Ok(())
+    }
 }
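Review note: for symmetry, the client-side shape of the new collection getter, condensed from `test_blob_get_collection` above. This is a fragment (`client` and `hash` as in that test); that `Collection` iterates over `(name, hash)` pairs is an assumption based on the `push`/`len` calls the test makes.

```rust
// Fragment, continuing the test setup above (not code from this diff).
let collection = client.blobs.get_collection(hash).await?;
for (name, blob_hash) in collection.iter() {
    // Each entry names one of the imported blobs; the bytes themselves can
    // be fetched separately, e.g. with read_at_to_bytes(blob_hash, ..).
    println!("{name}: {blob_hash}");
}
```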
.context("create file")?; + file.write_all(&buf.clone()).await.context("write_all")?; + file.flush().await.context("flush")?; + paths.push(path); + } + + let client = node.client(); + + let mut collection = Collection::default(); + let mut tags = Vec::new(); + // import files + for path in &paths { + let import_outcome = client + .blobs + .add_from_path( + path.to_path_buf(), + false, + SetTagOption::Auto, + WrapOption::NoWrap, + ) + .await + .context("import file")? + .finish() + .await + .context("import finish")?; + + collection.push( + path.file_name().unwrap().to_str().unwrap().to_string(), + import_outcome.hash, + ); + tags.push(import_outcome.tag); + } + + let (hash, _tag) = client + .blobs + .create_collection(collection, SetTagOption::Auto, tags) + .await?; + + let collection = client.blobs.get_collection(hash).await?; + + // 5 blobs + assert_eq!(collection.len(), 5); + + Ok(()) + } } diff --git a/iroh/src/node.rs b/iroh/src/node.rs index b6ffe671ff..3926a49d80 100644 --- a/iroh/src/node.rs +++ b/iroh/src/node.rs @@ -54,12 +54,13 @@ use url::Url; use crate::downloader::Downloader; use crate::rpc_protocol::{ BlobAddPathRequest, BlobAddPathResponse, BlobAddStreamRequest, BlobAddStreamResponse, - BlobAddStreamUpdate, BlobDeleteBlobRequest, BlobDownloadRequest, BlobListCollectionsRequest, - BlobListCollectionsResponse, BlobListIncompleteRequest, BlobListIncompleteResponse, - BlobListRequest, BlobListResponse, BlobReadRequest, BlobReadResponse, BlobValidateRequest, - CreateCollectionRequest, CreateCollectionResponse, DeleteTagRequest, DocExportFileRequest, - DocExportFileResponse, DocExportProgress, DocImportFileRequest, DocImportFileResponse, - DocImportProgress, DocSetHashRequest, DownloadLocation, ListTagsRequest, ListTagsResponse, + BlobAddStreamUpdate, BlobDeleteBlobRequest, BlobDownloadRequest, BlobGetCollectionRequest, + BlobGetCollectionResponse, BlobListCollectionsRequest, BlobListCollectionsResponse, + BlobListIncompleteRequest, BlobListIncompleteResponse, BlobListRequest, BlobListResponse, + BlobReadAtRequest, BlobReadAtResponse, BlobValidateRequest, CreateCollectionRequest, + CreateCollectionResponse, DeleteTagRequest, DocExportFileRequest, DocExportFileResponse, + DocExportProgress, DocImportFileRequest, DocImportFileResponse, DocImportProgress, + DocSetHashRequest, DownloadLocation, ListTagsRequest, ListTagsResponse, NodeConnectionInfoRequest, NodeConnectionInfoResponse, NodeConnectionsRequest, NodeConnectionsResponse, NodeShutdownRequest, NodeStatsRequest, NodeStatsResponse, NodeStatusRequest, NodeStatusResponse, NodeWatchRequest, NodeWatchResponse, ProviderRequest, @@ -1436,42 +1437,69 @@ impl RpcHandler { Ok(()) } - fn blob_read( + fn blob_read_at( self, - req: BlobReadRequest, - ) -> impl Stream> + Send + 'static { + req: BlobReadAtRequest, + ) -> impl Stream> + Send + 'static { let (tx, rx) = flume::bounded(RPC_BLOB_GET_CHANNEL_CAP); let entry = self.inner.db.get(&req.hash); self.inner.rt.spawn_pinned(move || async move { - if let Err(err) = read_loop(entry, tx.clone(), RPC_BLOB_GET_CHUNK_SIZE).await { + if let Err(err) = read_loop( + req.offset, + req.len, + entry, + tx.clone(), + RPC_BLOB_GET_CHUNK_SIZE, + ) + .await + { tx.send_async(RpcResult::Err(err.into())).await.ok(); } }); async fn read_loop( + offset: u64, + len: Option, entry: Option>, - tx: flume::Sender>, - chunk_size: usize, + tx: flume::Sender>, + max_chunk_size: usize, ) -> anyhow::Result<()> { let entry = entry.ok_or_else(|| anyhow!("Blob not found"))?; let size = entry.size(); - 
tx.send_async(Ok(BlobReadResponse::Entry { + tx.send_async(Ok(BlobReadAtResponse::Entry { size, is_complete: entry.is_complete(), })) .await?; let mut reader = entry.data_reader().await?; - let mut offset = 0u64; - loop { - let chunk = reader.read_at(offset, chunk_size).await?; - let len = chunk.len(); + + let len = len.unwrap_or_else(|| (size - offset) as usize); + + let (num_chunks, chunk_size) = if len <= max_chunk_size { + (1, len) + } else { + let num_chunks = len / max_chunk_size + (len % max_chunk_size != 0) as usize; + (num_chunks, max_chunk_size) + }; + + let mut read = 0u64; + for i in 0..num_chunks { + let chunk_size = if i == num_chunks - 1 { + // last chunk might be smaller + len - read as usize + } else { + chunk_size + }; + let chunk = reader.read_at(offset + read, chunk_size).await?; + let chunk_len = chunk.len(); if !chunk.is_empty() { - tx.send_async(Ok(BlobReadResponse::Data { chunk })).await?; + tx.send_async(Ok(BlobReadAtResponse::Data { chunk })) + .await?; } - if len < chunk_size { + if chunk_len < chunk_size { break; } else { - offset += len as u64; + read += chunk_len as u64; } } Ok(()) @@ -1543,6 +1571,21 @@ impl RpcHandler { Ok(CreateCollectionResponse { hash, tag }) } + + async fn blob_get_collection( + self, + req: BlobGetCollectionRequest, + ) -> RpcResult { + let hash = req.hash; + let db = self.inner.db.clone(); + let collection = self + .rt() + .spawn_pinned(move || async move { Collection::load(&db, &hash).await }) + .await + .map_err(|_| anyhow!("join failed"))??; + + Ok(BlobGetCollectionResponse { collection }) + } } fn handle_rpc_request>( @@ -1583,6 +1626,10 @@ fn handle_rpc_request>( .await } CreateCollection(msg) => chan.rpc(msg, handler, RpcHandler::create_collection).await, + BlobGetCollection(msg) => { + chan.rpc(msg, handler, RpcHandler::blob_get_collection) + .await + } ListTags(msg) => { chan.server_streaming(msg, handler, RpcHandler::blob_list_tags) .await @@ -1601,8 +1648,8 @@ fn handle_rpc_request>( chan.server_streaming(msg, handler, RpcHandler::blob_validate) .await } - BlobRead(msg) => { - chan.server_streaming(msg, handler, RpcHandler::blob_read) + BlobReadAt(msg) => { + chan.server_streaming(msg, handler, RpcHandler::blob_read_at) .await } BlobAddStream(msg) => { diff --git a/iroh/src/rpc_protocol.rs b/iroh/src/rpc_protocol.rs index 8464b94dc5..1e6ab8810d 100644 --- a/iroh/src/rpc_protocol.rs +++ b/iroh/src/rpc_protocol.rs @@ -274,6 +274,24 @@ impl RpcMsg for DeleteTagRequest { type Response = RpcResult<()>; } +/// Get a collection +#[derive(Debug, Serialize, Deserialize)] +pub struct BlobGetCollectionRequest { + /// Hash of the collection + pub hash: Hash, +} + +impl RpcMsg for BlobGetCollectionRequest { + type Response = RpcResult; +} + +/// The response for a `BlobGetCollectionRequest`. +#[derive(Debug, Serialize, Deserialize)] +pub struct BlobGetCollectionResponse { + /// The collection. + pub collection: Collection, +} + /// Create a collection. 
diff --git a/iroh/src/rpc_protocol.rs b/iroh/src/rpc_protocol.rs
index 8464b94dc5..1e6ab8810d 100644
--- a/iroh/src/rpc_protocol.rs
+++ b/iroh/src/rpc_protocol.rs
@@ -274,6 +274,24 @@
 impl RpcMsg for DeleteTagRequest {
     type Response = RpcResult<()>;
 }
 
+/// Get a collection
+#[derive(Debug, Serialize, Deserialize)]
+pub struct BlobGetCollectionRequest {
+    /// Hash of the collection
+    pub hash: Hash,
+}
+
+impl RpcMsg for BlobGetCollectionRequest {
+    type Response = RpcResult<BlobGetCollectionResponse>;
+}
+
+/// The response for a `BlobGetCollectionRequest`.
+#[derive(Debug, Serialize, Deserialize)]
+pub struct BlobGetCollectionResponse {
+    /// The collection.
+    pub collection: Collection,
+}
+
 /// Create a collection.
 #[derive(Debug, Serialize, Deserialize)]
 pub struct CreateCollectionRequest {
@@ -937,22 +955,26 @@ pub struct DocGetDownloadPolicyResponse {
 
 /// Get the bytes for a hash
 #[derive(Serialize, Deserialize, Debug)]
-pub struct BlobReadRequest {
+pub struct BlobReadAtRequest {
     /// Hash to get bytes for
     pub hash: Hash,
+    /// Offset to start reading at
+    pub offset: u64,
+    /// Length of the data to get
+    pub len: Option<usize>,
 }
 
-impl Msg for BlobReadRequest {
+impl Msg for BlobReadAtRequest {
     type Pattern = ServerStreaming;
 }
 
-impl ServerStreamingMsg for BlobReadRequest {
-    type Response = RpcResult<BlobReadResponse>;
+impl ServerStreamingMsg for BlobReadAtRequest {
+    type Response = RpcResult<BlobReadAtResponse>;
 }
 
-/// Response to [`BlobReadRequest`]
+/// Response to [`BlobReadAtRequest`]
 #[derive(Serialize, Deserialize, Debug)]
-pub enum BlobReadResponse {
+pub enum BlobReadAtResponse {
     /// The entry header.
     Entry {
         /// The size of the blob
@@ -1035,7 +1057,7 @@ pub enum ProviderRequest {
     NodeConnectionInfo(NodeConnectionInfoRequest),
     NodeWatch(NodeWatchRequest),
 
-    BlobRead(BlobReadRequest),
+    BlobReadAt(BlobReadAtRequest),
     BlobAddStream(BlobAddStreamRequest),
     BlobAddStreamUpdate(BlobAddStreamUpdate),
     BlobAddPath(BlobAddPathRequest),
@@ -1046,6 +1068,7 @@
     BlobDeleteBlob(BlobDeleteBlobRequest),
     BlobValidate(BlobValidateRequest),
     CreateCollection(CreateCollectionRequest),
+    BlobGetCollection(BlobGetCollectionRequest),
 
     DeleteTag(DeleteTagRequest),
     ListTags(ListTagsRequest),
@@ -1087,7 +1110,7 @@ pub enum ProviderResponse {
     NodeShutdown(()),
     NodeWatch(NodeWatchResponse),
 
-    BlobRead(RpcResult<BlobReadResponse>),
+    BlobReadAt(RpcResult<BlobReadAtResponse>),
     BlobAddStream(BlobAddStreamResponse),
     BlobAddPath(BlobAddPathResponse),
     BlobDownload(DownloadProgress),
@@ -1096,6 +1119,7 @@
     BlobListCollections(BlobListCollectionsResponse),
     BlobValidate(ValidateProgress),
     CreateCollection(RpcResult<CreateCollectionResponse>),
+    BlobGetCollection(RpcResult<BlobGetCollectionResponse>),
 
     ListTags(ListTagsResponse),
     DeleteTag(RpcResult<()>),
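Review note, last one: the wire contract `BlobReadAtResponse` implies, stated as executable documentation. A response stream carries exactly one leading `Entry` header frame, then only `Data` frames; this is the same rule `from_rpc_read_at` enforces with its "Expected header frame" / "Expected data frame" errors. The enum is re-declared locally so the sketch compiles on its own; the real type is the one above.

```rust
// Local stand-in mirroring BlobReadAtResponse above, to keep the sketch
// self-contained. The real type is the serde enum in rpc_protocol.rs.
enum Frame {
    Entry { size: u64, is_complete: bool },
    Data { chunk: Vec<u8> },
}

/// Enforce the framing contract: one leading Entry header, then only Data
/// frames. Returns (total blob size, bytes received).
fn validate_frames(frames: impl IntoIterator<Item = Frame>) -> Result<(u64, u64), &'static str> {
    let mut frames = frames.into_iter();
    let size = match frames.next() {
        Some(Frame::Entry { size, .. }) => size,
        _ => return Err("expected header frame"),
    };
    let mut received = 0u64;
    for frame in frames {
        match frame {
            Frame::Data { chunk } => received += chunk.len() as u64,
            Frame::Entry { .. } => return Err("expected data frame"),
        }
    }
    Ok((size, received))
}

fn main() {
    let frames = vec![
        Frame::Entry { size: 128, is_complete: true },
        Frame::Data { chunk: vec![0; 64] },
        Frame::Data { chunk: vec![0; 36] }, // ranged read: fewer bytes than `size`
    ];
    assert_eq!(validate_frames(frames), Ok((128, 100)));
}
```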