diff --git a/google/cloud/storage/_experimental/asyncio/async_multi_range_downloader.py b/google/cloud/storage/_experimental/asyncio/async_multi_range_downloader.py index a458a5e43..f9081d0d6 100644 --- a/google/cloud/storage/_experimental/asyncio/async_multi_range_downloader.py +++ b/google/cloud/storage/_experimental/asyncio/async_multi_range_downloader.py @@ -71,16 +71,23 @@ class AsyncMultiRangeDownloader: client, bucket_name="chandrasiri-rs", object_name="test_open9" ) my_buff1 = open('my_fav_file.txt', 'wb') + my_buff1 = open('my_fav_file.txt', 'wb') my_buff2 = BytesIO() my_buff3 = BytesIO() my_buff4 = any_object_which_provides_BytesIO_like_interface() + results_arr, error_obj = await mrd.download_ranges( + my_buff4 = any_object_which_provides_BytesIO_like_interface() results_arr, error_obj = await mrd.download_ranges( [ + # (start_byte, bytes_to_read, writeable_buffer) # (start_byte, bytes_to_read, writeable_buffer) (0, 100, my_buff1), (100, 20, my_buff2), (200, 123, my_buff3), (300, 789, my_buff4), + (100, 20, my_buff2), + (200, 123, my_buff3), + (300, 789, my_buff4), ] ) if error_obj: @@ -94,6 +101,17 @@ class AsyncMultiRangeDownloader: for result in results_arr: print("downloaded bytes", result) + if error_obj: + print("Error occurred: ") + print(error_obj) + print( + "please issue call to `download_ranges` with updated " + "`read_ranges` based on diff of (bytes_requested - bytes_written)" + ) + + for result in results_arr: + print("downloaded bytes", result) + """ @@ -165,7 +183,8 @@ def __init__( self.object_name = object_name self.generation_number = generation_number self.read_handle = read_handle - self.read_obj_str: _AsyncReadObjectStream = None + self.read_obj_str: Optional[_AsyncReadObjectStream] = None + self._is_stream_open: bool = False async def open(self) -> None: """Opens the bidi-gRPC connection to read from the object. 
@@ -176,14 +195,19 @@ async def open(self) -> None: "Opening" constitutes fetching object metadata such as generation number and read handle and sets them as attributes if not already set. """ - self.read_obj_str = _AsyncReadObjectStream( - client=self.client, - bucket_name=self.bucket_name, - object_name=self.object_name, - generation_number=self.generation_number, - read_handle=self.read_handle, - ) + if self._is_stream_open: + raise ValueError("Underlying bidi-gRPC stream is already open") + + if self.read_obj_str is None: + self.read_obj_str = _AsyncReadObjectStream( + client=self.client, + bucket_name=self.bucket_name, + object_name=self.object_name, + generation_number=self.generation_number, + read_handle=self.read_handle, + ) await self.read_obj_str.open() + self._is_stream_open = True if self.generation_number is None: self.generation_number = self.read_obj_str.generation_number self.read_handle = self.read_obj_str.read_handle @@ -206,11 +230,15 @@ async def download_ranges( to a requested range. """ + if len(read_ranges) > 1000: raise ValueError( "Invalid input - length of read_ranges cannot be more than 1000" ) + if not self._is_stream_open: + raise ValueError("Underlying bidi-gRPC stream is not open") + read_id_to_writable_buffer_dict = {} results = [] for i in range(0, len(read_ranges), _MAX_READ_RANGES_PER_BIDI_READ_REQUEST): @@ -255,4 +283,18 @@ async def download_ranges( del read_id_to_writable_buffer_dict[ object_data_range.read_range.read_id ] + return results + + async def close(self): + """ + Closes the underlying bidi-gRPC connection. 
+ """ + if not self._is_stream_open: + raise ValueError("Underlying bidi-gRPC stream is not open") + await self.read_obj_str.close() + self._is_stream_open = False + + @property + def is_stream_open(self) -> bool: + return self._is_stream_open diff --git a/tests/unit/asyncio/test_async_multi_range_downloader.py b/tests/unit/asyncio/test_async_multi_range_downloader.py index b57bc92ca..27d1ed6dd 100644 --- a/tests/unit/asyncio/test_async_multi_range_downloader.py +++ b/tests/unit/asyncio/test_async_multi_range_downloader.py @@ -30,6 +30,14 @@ class TestAsyncMultiRangeDownloader: + def create_read_ranges(self, num_ranges): + ranges = [] + for i in range(num_ranges): + ranges.append( + _storage_v2.ReadRange(read_offset=i, read_length=1, read_id=i) + ) + return ranges + # helper method @pytest.mark.asyncio async def _make_mock_mrd( @@ -76,6 +84,16 @@ async def test_create_mrd( read_handle=_TEST_READ_HANDLE, ) + mrd.read_obj_str.open.assert_called_once() + # Assert + mock_cls_async_read_object_stream.assert_called_once_with( + client=mock_grpc_client, + bucket_name=_TEST_BUCKET_NAME, + object_name=_TEST_OBJECT_NAME, + generation_number=_TEST_GENERATION_NUMBER, + read_handle=_TEST_READ_HANDLE, + ) + mrd.read_obj_str.open.assert_called_once() assert mrd.client == mock_grpc_client @@ -83,6 +101,7 @@ async def test_create_mrd( assert mrd.object_name == _TEST_OBJECT_NAME assert mrd.generation_number == _TEST_GENERATION_NUMBER assert mrd.read_handle == _TEST_READ_HANDLE + assert mrd.is_stream_open @mock.patch( "google.cloud.storage._experimental.asyncio.async_multi_range_downloader._AsyncReadObjectStream" @@ -131,14 +150,6 @@ async def test_download_ranges( assert results[0].bytes_written == 18 assert buffer.getvalue() == b"these_are_18_chars" - def create_read_ranges(self, num_ranges): - ranges = [] - for i in range(num_ranges): - ranges.append( - _storage_v2.ReadRange(read_offset=i, read_length=1, read_id=i) - ) - return ranges - @mock.patch( 
"google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" ) @@ -160,3 +171,83 @@ async def test_downloading_ranges_with_more_than_1000_should_throw_error( str(exc.value) == "Invalid input - length of read_ranges cannot be more than 1000" ) + + @mock.patch( + "google.cloud.storage._experimental.asyncio.async_multi_range_downloader._AsyncReadObjectStream" + ) + @mock.patch( + "google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" + ) + @pytest.mark.asyncio + async def test_opening_mrd_more_than_once_should_throw_error( + self, mock_grpc_client, mock_cls_async_read_object_stream + ): + # Arrange + mrd = await self._make_mock_mrd( + mock_grpc_client, mock_cls_async_read_object_stream + ) # mock mrd is already opened + + # Act + Assert + with pytest.raises(ValueError) as exc: + await mrd.open() + + # Assert + assert str(exc.value) == "Underlying bidi-gRPC stream is already open" + + @mock.patch( + "google.cloud.storage._experimental.asyncio.async_multi_range_downloader._AsyncReadObjectStream" + ) + @mock.patch( + "google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" + ) + @pytest.mark.asyncio + async def test_close_mrd(self, mock_grpc_client, mock_cls_async_read_object_stream): + # Arrange + mrd = await self._make_mock_mrd( + mock_grpc_client, mock_cls_async_read_object_stream + ) # mock mrd is already opened + mrd.read_obj_str.close = AsyncMock() + + # Act + await mrd.close() + + # Assert + assert not mrd.is_stream_open + + @mock.patch( + "google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" + ) + @pytest.mark.asyncio + async def test_close_mrd_not_opened_should_throw_error(self, mock_grpc_client): + # Arrange + mrd = AsyncMultiRangeDownloader( + mock_grpc_client, _TEST_BUCKET_NAME, _TEST_OBJECT_NAME + ) + + # Act + Assert + with pytest.raises(ValueError) as exc: + await mrd.close() + + # Assert + assert str(exc.value) == 
"Underlying bidi-gRPC stream is not open" + assert not mrd.is_stream_open + + @mock.patch( + "google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" + ) + @pytest.mark.asyncio + async def test_downloading_without_opening_should_throw_error( + self, mock_grpc_client + ): + # Arrange + mrd = AsyncMultiRangeDownloader( + mock_grpc_client, _TEST_BUCKET_NAME, _TEST_OBJECT_NAME + ) + + # Act + Assert + with pytest.raises(ValueError) as exc: + await mrd.download_ranges([(0, 18, BytesIO())]) + + # Assert + assert str(exc.value) == "Underlying bidi-gRPC stream is not open" + assert not mrd.is_stream_open