Skip to content
This repository has been archived by the owner on Sep 22, 2023. It is now read-only.

Commit

Permalink
feat: Allow overriding storage proxy address mapping (#194)
Browse files Browse the repository at this point in the history
Co-authored-by: Joongi Kim <joongi@lablup.com>
  • Loading branch information
fregataa and achimnol committed Mar 14, 2022
1 parent cbfe074 commit 5bba387
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 13 deletions.
1 change: 1 addition & 0 deletions changes/194.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add storage proxy address overriding option
28 changes: 28 additions & 0 deletions src/ai/backend/client/cli/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from decimal import Decimal
from typing import (
Any,
Mapping,
Union,
Optional,
)

Expand Down Expand Up @@ -55,6 +57,32 @@ def convert(self, value, param, ctx):
return value


class CommaSeparatedKVListParamType(click.ParamType):
name = "comma-seperated-KVList-check"

def convert(self, value: Union[str, Mapping[str, str]], param, ctx) -> Mapping[str, str]:
if isinstance(value, dict):
return value
if not isinstance(value, str):
self.fail(
f"expected string, got {value!r} of type {type(value).__name__}",
param, ctx,
)
override_map = {}
for assignment in value.split(","):
try:
k, _, v = assignment.partition("=")
if k == '' or v == '':
raise ValueError(f"key or value is empty. key = {k}, value = {v}")
except ValueError:
self.fail(
f"{value!r} is not a valid mapping expression", param, ctx,
)
else:
override_map[k] = v
return override_map


class JSONParamType(click.ParamType):
"""
A JSON string parameter type.
Expand Down
22 changes: 18 additions & 4 deletions src/ai/backend/client/cli/vfolder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@
from tqdm import tqdm

from ai.backend.cli.interaction import ask_yn
from ai.backend.client.config import DEFAULT_CHUNK_SIZE
from ai.backend.client.config import DEFAULT_CHUNK_SIZE, APIConfig
from ai.backend.client.session import Session

from ..compat import asyncio_run
from ..session import AsyncSession
from .main import main
from .pretty import print_done, print_error, print_fail, print_info, print_wait, print_warn
from .params import ByteSizeParamType, ByteSizeParamCheckType
from .params import ByteSizeParamType, ByteSizeParamCheckType, CommaSeparatedKVListParamType


@main.group()
Expand Down Expand Up @@ -180,7 +180,13 @@ def info(name):
help='Transfer the file with the given chunk size with binary suffixes (e.g., "16m"). '
'Set this between 8 to 64 megabytes for high-speed disks (e.g., SSD RAID) '
'and networks (e.g., 40 GbE) for the maximum throughput.')
def upload(name, filenames, base_dir, chunk_size):
@click.option('--override-storage-proxy',
type=CommaSeparatedKVListParamType(), default=None,
help='Overrides storage proxy address. '
'The value must shape like "X1=Y1,X2=Y2...". '
'Each Yn address must at least include the IP address '
'or the hostname and may include the protocol part and the port number to replace.')
def upload(name, filenames, base_dir, chunk_size, override_storage_proxy):
'''
TUS Upload a file to the virtual folder from the current working directory.
The files with the same names will be overwirtten.
Expand All @@ -196,6 +202,7 @@ def upload(name, filenames, base_dir, chunk_size):
basedir=base_dir,
chunk_size=chunk_size,
show_progress=True,
address_map=override_storage_proxy or APIConfig.DEFAULTS['storage_proxy_address_map'],
)
print_done('Done.')
except Exception as e:
Expand All @@ -214,7 +221,13 @@ def upload(name, filenames, base_dir, chunk_size):
help='Transfer the file with the given chunk size with binary suffixes (e.g., "16m"). '
'Set this between 8 to 64 megabytes for high-speed disks (e.g., SSD RAID) '
'and networks (e.g., 40 GbE) for the maximum throughput.')
def download(name, filenames, base_dir, chunk_size):
@click.option('--override-storage-proxy',
type=CommaSeparatedKVListParamType(), default=None,
help='Overrides storage proxy address. '
'The value must shape like "X1=Y1,X2=Y2...". '
'Each Yn address must at least include the IP address '
'or the hostname and may include the protocol part and the port number to replace.')
def download(name, filenames, base_dir, chunk_size, override_storage_proxy):
'''
Download a file from the virtual folder to the current working directory.
The files with the same names will be overwirtten.
Expand All @@ -230,6 +243,7 @@ def download(name, filenames, base_dir, chunk_size):
basedir=base_dir,
chunk_size=chunk_size,
show_progress=True,
address_map=override_storage_proxy or APIConfig.DEFAULTS['storage_proxy_address_map'],
)
print_done('Done.')
except Exception as e:
Expand Down
52 changes: 45 additions & 7 deletions src/ai/backend/client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import random
import re
from typing import (
Any,
Callable,
Iterable,
List,
Expand Down Expand Up @@ -57,15 +58,15 @@ def parse_api_version(value: str) -> Tuple[int, str]:
T = TypeVar('T')


def default_clean(v: str) -> T:
def default_clean(v: Union[str, Mapping]) -> T:
return cast(T, v)


def get_env(
key: str,
default: Union[str, Undefined] = _undefined,
default: Union[str, Mapping, Undefined] = _undefined,
*,
clean: Callable[[str], T] = default_clean,
clean: Callable[[Any], T] = default_clean,
) -> T:
"""
Retrieves a configuration value from the environment variables.
Expand All @@ -88,8 +89,10 @@ def get_env(
if raw is None:
if default is _undefined:
raise KeyError(key)
raw = default
return clean(raw)
result = default
else:
result = raw
return clean(result)


def bool_env(v: str) -> bool:
Expand Down Expand Up @@ -120,6 +123,26 @@ def _clean_tokens(v: str) -> Tuple[str, ...]:
return tuple(v.split(','))


def _clean_address_map(v: Union[str, Mapping]) -> Mapping:
if isinstance(v, dict):
return v
if not isinstance(v, str):
raise ValueError(
f'Storage proxy address map has invalid type "{type(v)}", expected str or dict.',
)
override_map = {}
for assignment in v.split(","):
try:
k, _, v = assignment.partition("=")
if k == '' or v == '':
raise ValueError
except ValueError:
raise ValueError(f"{v} is not a valid mapping expression")
else:
override_map[k] = v
return override_map


class APIConfig:
"""
Represents a set of API client configurations.
Expand Down Expand Up @@ -157,13 +180,14 @@ class APIConfig:
<ai.backend.client.kernel.Kernel.get_or_create>` calls.
"""

DEFAULTS: Mapping[str, str] = {
DEFAULTS: Mapping[str, Union[str, Mapping]] = {
'endpoint': 'https://api.backend.ai',
'endpoint_type': 'api',
'version': f'v{API_VERSION[0]}.{API_VERSION[1]}',
'hash_type': 'sha256',
'domain': 'default',
'group': 'default',
'storage_proxy_address_map': {},
'connection_timeout': '10.0',
'read_timeout': '0',
}
Expand All @@ -183,6 +207,7 @@ def __init__(
endpoint_type: str = None,
domain: str = None,
group: str = None,
storage_proxy_address_map: Mapping[str, str] = None,
version: str = None,
user_agent: str = None,
access_key: str = None,
Expand All @@ -206,8 +231,16 @@ def __init__(
get_env('DOMAIN', self.DEFAULTS['domain'], clean=str)
self._group = group if group is not None else \
get_env('GROUP', self.DEFAULTS['group'], clean=str)
self._storage_proxy_address_map = storage_proxy_address_map \
if storage_proxy_address_map is not None else \
get_env(
'OVERRIDE_STORAGE_PROXY',
self.DEFAULTS['storage_proxy_address_map'],
# The shape of this env var must be like "X1=Y1,X2=Y2"
clean=_clean_address_map,
)
self._version = version if version is not None else \
self.DEFAULTS['version']
default_clean(self.DEFAULTS['version'])
self._user_agent = user_agent if user_agent is not None else get_user_agent()
if self._endpoint_type == 'api':
self._access_key = access_key if access_key is not None else \
Expand Down Expand Up @@ -278,6 +311,11 @@ def group(self) -> str:
"""The configured group."""
return self._group

@property
def storage_proxy_address_map(self) -> Mapping[str, str]:
"""The storage proxy address map for overriding."""
return self.storage_proxy_address_map

@property
def user_agent(self) -> str:
"""The configured user agent string."""
Expand Down
28 changes: 26 additions & 2 deletions src/ai/backend/client/func/vfolder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import asyncio
from pathlib import Path
from typing import (
Mapping,
Optional,
Sequence,
Union,
)
Expand All @@ -17,6 +19,7 @@
from .base import api_function, BaseFunction
from ..compat import current_loop
from ..config import DEFAULT_CHUNK_SIZE, MAX_INFLIGHT_CHUNKS
from ..exceptions import BackendClientError
from ..pagination import generate_paginated_results
from ..request import Request

Expand Down Expand Up @@ -166,6 +169,7 @@ async def download(
basedir: Union[str, Path] = None,
chunk_size: int = DEFAULT_CHUNK_SIZE,
show_progress: bool = False,
address_map: Optional[Mapping[str, str]] = None,
) -> None:
base_path = (Path.cwd() if basedir is None else Path(basedir).resolve())
for relpath in relative_paths:
Expand All @@ -177,7 +181,17 @@ async def download(
})
async with rqst.fetch() as resp:
download_info = await resp.json()
download_url = URL(download_info['url']).with_query({
overriden_url = download_info['url']
if address_map:
if download_info['url'] in address_map:
overriden_url = address_map[download_info['url']]
else:
raise BackendClientError(
'Overriding storage proxy addresses are given, '
'but no url matches with any of them.\n',
)

download_url = URL(overriden_url).with_query({
'token': download_info['token'],
})

Expand Down Expand Up @@ -229,6 +243,7 @@ async def upload(
*,
basedir: Union[str, Path] = None,
chunk_size: int = DEFAULT_CHUNK_SIZE,
address_map: Optional[Mapping[str, str]] = None,
show_progress: bool = False,
) -> None:
base_path = (Path.cwd() if basedir is None else Path(basedir).resolve())
Expand All @@ -246,7 +261,16 @@ async def upload(
})
async with rqst.fetch() as resp:
upload_info = await resp.json()
upload_url = URL(upload_info['url']).with_query({
overriden_url = upload_info['url']
if address_map:
if upload_info['url'] in address_map:
overriden_url = address_map[upload_info['url']]
else:
raise BackendClientError(
'Overriding storage proxy addresses are given, '
'but no url matches with any of them.\n',
)
upload_url = URL(overriden_url).with_query({
'token': upload_info['token'],
})
tus_client = client.TusClient()
Expand Down

0 comments on commit 5bba387

Please sign in to comment.