diff --git a/src/bentoml/_internal/cloud/base.py b/src/bentoml/_internal/cloud/base.py index d1173f71e31..107252c77c0 100644 --- a/src/bentoml/_internal/cloud/base.py +++ b/src/bentoml/_internal/cloud/base.py @@ -5,6 +5,7 @@ from abc import ABC from abc import abstractmethod from contextlib import contextmanager +from tempfile import SpooledTemporaryFile from rich.console import Group from rich.panel import Panel @@ -24,6 +25,60 @@ from ..tag import Tag FILE_CHUNK_SIZE = 100 * 1024 * 1024 # 100Mb +SPOOLED_FILE_MAX_SIZE = 5 * 1024 * 1024 * 1024 # 5GB + + +class CallbackSpooledTemporaryFileIO(SpooledTemporaryFile): + """ + A SpooledTemporaryFile wrapper that calls + a callback when read/write is called + """ + + read_cb: t.Callable[[int], None] | None + write_cb: t.Callable[[int], None] | None + + def __init__( + self, + max_size: int = 0, + *, + read_cb: t.Callable[[int], None] | None = None, + write_cb: t.Callable[[int], None] | None = None, + ): + self.read_cb = read_cb + self.write_cb = write_cb + super().__init__(max_size) + + def read(self, *args): + res = super().read(*args) + if self.read_cb is not None: + self.read_cb(len(res)) + return res + + def write(self, s): + res = super().write(s) + if self.write_cb is not None: + if hasattr(s, "__len__"): + self.write_cb(len(s)) + return res + + def size(self) -> int: + """ + get the size of the file + """ + current_pos = self.tell() + self.seek(0, 2) + file_size = self.tell() + self.seek(current_pos) + return file_size + + def chunk(self, start: int, end: int) -> bytes: + """ + chunk the file slice of [start, end) + """ + self.seek(start) + if end < 0 or start > end: + return self.read() + return self.read(end - start) class CallbackIOWrapper(io.BytesIO): diff --git a/src/bentoml/_internal/cloud/bentocloud.py b/src/bentoml/_internal/cloud/bentocloud.py index 169c6469e34..218f1914f67 100644 --- a/src/bentoml/_internal/cloud/bentocloud.py +++ b/src/bentoml/_internal/cloud/bentocloud.py @@ -26,7 +26,9 @@ from ..tag import Tag from ..utils import calc_dir_size from .base import FILE_CHUNK_SIZE +from .base import SPOOLED_FILE_MAX_SIZE from .base import CallbackIOWrapper +from .base import CallbackSpooledTemporaryFileIO from .base import CloudClient from .config import get_rest_api_client from .deployment import Deployment @@ -667,7 +669,9 @@ def io_cb(x: int): with io_mutex: self.transmission_progress.update(upload_task_id, advance=x) - with CallbackIOWrapper(read_cb=io_cb) as tar_io: + with CallbackSpooledTemporaryFileIO( + SPOOLED_FILE_MAX_SIZE, read_cb=io_cb + ) as tar_io: with self.spin(text=f'Creating tar archive for model "{model.tag}"..'): with tarfile.open(fileobj=tar_io, mode="w:") as tar: tar.add(model.path, arcname="./") @@ -676,7 +680,7 @@ def io_cb(x: int): yatai_rest_client.start_upload_model( model_repository_name=model_repository.name, version=version ) - file_size = tar_io.getbuffer().nbytes + file_size = tar_io.size() self.transmission_progress.update( upload_task_id, description=f'Uploading model "{model.tag}"', @@ -751,15 +755,14 @@ def chunk_upload( text=f'({chunk_number}/{chunks_count}) Uploading chunk of model "{model.tag}"...' ): chunk = ( - tar_io.getbuffer()[ - (chunk_number - 1) - * FILE_CHUNK_SIZE : chunk_number - * FILE_CHUNK_SIZE - ] + tar_io.chunk( + (chunk_number - 1) * FILE_CHUNK_SIZE, + chunk_number * FILE_CHUNK_SIZE, + ) if chunk_number < chunks_count - else tar_io.getbuffer()[ - (chunk_number - 1) * FILE_CHUNK_SIZE : - ] + else tar_io.chunk( + (chunk_number - 1) * FILE_CHUNK_SIZE, -1 + ) ) with CallbackIOWrapper(chunk, read_cb=io_cb) as chunk_io: