From c1d268622998dbb3c3504a0692e7b733cb8f1e0b Mon Sep 17 00:00:00 2001 From: felix-wang <35718120+numb3r3@users.noreply.github.com> Date: Sun, 24 Apr 2022 15:09:20 +0800 Subject: [PATCH] fix: download with resume (#689) * fix: download with resume * fix: pass a valid user-agent * fix: unttest --- server/clip_server/model/clip.py | 54 +++++++++++++++++++++------ server/clip_server/model/clip_onnx.py | 8 +++- tests/test_server.py | 21 ++++++++++- 3 files changed, 68 insertions(+), 15 deletions(-) diff --git a/server/clip_server/model/clip.py b/server/clip_server/model/clip.py index 04acbd470..1dad6acc1 100644 --- a/server/clip_server/model/clip.py +++ b/server/clip_server/model/clip.py @@ -3,6 +3,7 @@ import os import io import urllib +import shutil import warnings from typing import Union, List @@ -36,7 +37,7 @@ } -def _download(url: str, root: str): +def _download(url: str, root: str, with_resume: bool = True): os.makedirs(root, exist_ok=True) filename = os.path.basename(url) @@ -70,20 +71,48 @@ def _download(url: str, root: str): task = progress.add_task('download', filename=url, start=False) - with urllib.request.urlopen(url) as source, open( - download_target, 'wb' - ) as output: + tmp_file_path = download_target + '.part' + resume_byte_pos = ( + os.path.getsize(tmp_file_path) if os.path.exists(tmp_file_path) else 0 + ) + + total_bytes = -1 + try: + # resolve the 403 error by passing a valid user-agent + req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'}) + + total_bytes = int( + urllib.request.urlopen(req).info().get('Content-Length', -1) + ) + + mode = 'ab' if (with_resume and resume_byte_pos) else 'wb' + + with open(tmp_file_path, mode) as output: + + progress.update(task, total=total_bytes) + + progress.start_task(task) - progress.update(task, total=int(source.info().get('Content-Length'))) + if resume_byte_pos and with_resume: + progress.update(task, advance=resume_byte_pos) + req.headers['Range'] = f'bytes={resume_byte_pos}-' - progress.start_task(task) - while True: - buffer = source.read(8192) - if not buffer: - break + with urllib.request.urlopen(req) as source: + while True: + buffer = source.read(8192) + if not buffer: + break - output.write(buffer) - progress.update(task, advance=len(buffer)) + output.write(buffer) + progress.update(task, advance=len(buffer)) + except Exception as ex: + raise ex + finally: + # rename the temp download file to the correct name if fully downloaded + if os.path.exists(tmp_file_path) and ( + total_bytes == os.path.getsize(tmp_file_path) + ): + shutil.move(tmp_file_path, download_target) return download_target @@ -165,6 +194,7 @@ def load( model_path = _download( _S3_BUCKET + _MODELS[name], download_root or os.path.expanduser('~/.cache/clip'), + with_resume=True, ) elif os.path.isfile(name): model_path = name diff --git a/server/clip_server/model/clip_onnx.py b/server/clip_server/model/clip_onnx.py index c5a237938..9d42efc7f 100644 --- a/server/clip_server/model/clip_onnx.py +++ b/server/clip_server/model/clip_onnx.py @@ -20,8 +20,12 @@ class CLIPOnnxModel: def __init__(self, name: str = None): if name in _MODELS: cache_dir = os.path.expanduser(f'~/.cache/clip/{name.replace("/", "-")}') - self._textual_path = _download(_S3_BUCKET + _MODELS[name][0], cache_dir) - self._visual_path = _download(_S3_BUCKET + _MODELS[name][1], cache_dir) + self._textual_path = _download( + _S3_BUCKET + _MODELS[name][0], cache_dir, with_resume=True + ) + self._visual_path = _download( + _S3_BUCKET + _MODELS[name][1], cache_dir, with_resume=True + ) else: raise RuntimeError( f'Model {name} not found; available models = {available_models()}' diff --git a/tests/test_server.py b/tests/test_server.py index d2bf5b00d..c1d99748d 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,11 +1,30 @@ import os import pytest -from clip_server.model.clip import _transform_ndarray, _transform_blob +from clip_server.model.clip import _transform_ndarray, _transform_blob, _download from docarray import Document import numpy as np +def test_server_download(tmpdir): + _download('https://docarray.jina.ai/_static/favicon.png', tmpdir, with_resume=False) + + target_path = os.path.join(tmpdir, 'favicon.png') + file_size = os.path.getsize(target_path) + assert file_size > 0 + + part_path = target_path + '.part' + with open(target_path, 'rb') as source, open(part_path, 'wb') as part_out: + buf = source.read(10) + part_out.write(buf) + + os.remove(target_path) + + _download('https://docarray.jina.ai/_static/favicon.png', tmpdir, with_resume=True) + assert os.path.getsize(target_path) == file_size + assert not os.path.exists(part_path) + + @pytest.mark.parametrize( 'image_uri', [