feat: add scale to Flow and Pod (#3816)
Co-authored-by: bwanglzu <bo.wang@jina.ai>
JoanFM and bwanglzu committed Nov 15, 2021
1 parent 29cb116 commit df57ce1
Showing 19 changed files with 613 additions and 11 deletions.
4 changes: 4 additions & 0 deletions jina/excepts.py
@@ -25,6 +25,10 @@ class RuntimeFailToStart(SystemError, BaseJinaExeception):
"""When pea/pod is failed to started."""


class ScalingFails(SystemError, BaseJinaExeception):
"""When scaling is unsuccessful for an Executor."""


class MemoryOverHighWatermark(Exception, BaseJinaExeception):
"""When the memory usage is over the defined high water mark."""

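The new `ScalingFails` exception propagates out of the `Flow.scale` method added below, so callers can react to a failed scale without tearing down the Flow. A minimal sketch, assuming a running Flow `f` with a replicated executor named `my_executor` (both hypothetical):

from jina.excepts import ScalingFails

try:
    f.scale(pod_name='my_executor', replicas=4)
except ScalingFails as ex:
    # per the note in _ReplicaSet.scale, the ReplicaSet keeps its previous state
    print(f'scaling failed, keeping the current replicas: {ex!r}')
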
22 changes: 22 additions & 0 deletions jina/flow/base.py
@@ -1827,6 +1827,28 @@ def rolling_update(
any_event_loop=True,
)

def scale(
self,
pod_name: str,
replicas: int,
):
"""
Scale the amount of replicas of a given Executor.
:param pod_name: pod to update
:param replicas: The number of replicas to scale to
"""

# TODO when replicas-host is ready, needs to be passed here

from ..helper import run_async

run_async(
self._pod_nodes[pod_name].scale,
replicas=replicas,
any_event_loop=True,
)

@property
def client_args(self) -> argparse.Namespace:
"""Get Client settings.
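For reference, a minimal usage sketch of the new Flow-level API; the executor name and replica counts are illustrative, and `scale` blocks until the new replicas are ready because it wraps the async Pod method via `run_async`:

from jina import Flow

# start a Flow with a replicated executor, then resize it at runtime
with Flow().add(name='my_executor', replicas=2) as f:
    f.scale(pod_name='my_executor', replicas=4)  # add two replicas
    f.scale(pod_name='my_executor', replicas=1)  # remove three replicas
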
2 changes: 1 addition & 1 deletion jina/parsers/peapods/pea.py
@@ -76,7 +76,7 @@ def mixin_pea_parser(parser):
gp.add_argument(
'--replica-id',
type=int,
default=0, # not sure how to maintain backwards compatibility with the workspace of Executor
default=0,
help='the id of the replica of an executor'
if _SHOW_ALL_ARGS
else argparse.SUPPRESS,
93 changes: 92 additions & 1 deletion jina/peapods/pods/__init__.py
@@ -10,6 +10,7 @@
from ..peas import BasePea
from ... import __default_executor__
from ... import helper
from ...excepts import RuntimeFailToStart, RuntimeRunForeverEarlyError, ScalingFails
from ...enums import (
SchedulerType,
PodRoleType,
@@ -108,6 +109,16 @@ async def rolling_update(self, *args, **kwargs):
"""
...

@abstractmethod
async def scale(self, *args, **kwargs):
"""
Scale the number of replicas of a given Executor.
.. # noqa: DAR201
.. # noqa: DAR101
"""
...

@staticmethod
def _set_upload_files(args):
# sets args.upload_files at the pod level so that peas inherit from it.
@@ -331,7 +342,7 @@ class Pod(BasePod):

class _ReplicaSet:
def __init__(self, pod_args: Namespace, args: List[Namespace]):
self.pod_args = pod_args
self.pod_args = copy.copy(pod_args)
self.args = args
self._peas = []

@@ -361,6 +372,7 @@ def wait_start_success(self):
async def rolling_update(
self, dump_path: Optional[str] = None, *, uses_with: Optional[Dict] = None
):
# TODO make rolling_update robust: in what state does this ReplicaSet end if this fails?
for i in range(len(self._peas)):
old_pea = self._peas[i]
old_pea.close()
@@ -374,8 +386,77 @@ async def rolling_update(
new_pea.__enter__()
await new_pea.async_wait_start_success()
new_pea.activate_runtime()
self.args[i] = _args
self._peas[i] = new_pea

async def _scale_up(self, replicas: int):
new_peas = []
new_args_list = []
for i in range(len(self._peas), replicas):
new_args = copy.copy(self.args[0])
new_args.noblock_on_start = True
new_args.name = new_args.name[:-1] + f'{i}'  # swap the trailing replica index in the name (assumes a single-digit suffix)
new_args.port_ctrl = helper.random_port()
new_args.replica_id = i
# no exception should happen at create and enter time
new_peas.append(BasePea(new_args).start())
new_args_list.append(new_args)
exception = None
for new_pea, new_args in zip(new_peas, new_args_list):
try:
await new_pea.async_wait_start_success()
except (
RuntimeFailToStart,
TimeoutError,
RuntimeRunForeverEarlyError,
) as ex:
exception = ex
break

if exception is not None:
# close peas and remove them from exitfifo
if self.pod_args.shards > 1:
msg = f' Scaling fails for shard {self.pod_args.shard_id}'
else:
msg = ' Scaling fails'

msg += f' due to the Executor failing to start with exception: {exception!r}'
raise ScalingFails(msg)
else:
for new_pea, new_args in zip(new_peas, new_args_list):
new_pea.activate_runtime()
self.args.append(new_args)
self._peas.append(new_pea)

async def _scale_down(self, replicas: int):
for i in reversed(range(replicas, len(self._peas))):
# `close` may raise an exception, but in theory termination should handle closing properly
try:
self._peas[i].close()
finally:
# If an exception is raised at close time, the pea most likely terminated abruptly,
# so these peas are useless anyway
del self._peas[i]
del self.args[i]

async def scale(self, replicas: int):
"""
Scale the number of replicas of a given Executor.
:param replicas: The number of replicas to scale to
.. note:: Scaling is all-or-nothing: if any replica fails to start, the ReplicaSet remains in its previous state.
"""
# TODO make scale robust: in what state does this ReplicaSet end if this fails?
assert replicas > 0
if replicas > len(self._peas):
await self._scale_up(replicas)
elif replicas < len(self._peas):
await self._scale_down(
replicas
) # scale down has some challenges with the exit fifo
self.pod_args.replicas = replicas

def __enter__(self):
for _args in self.args:
if getattr(self.pod_args, 'noblock_on_start', False):
@@ -691,6 +772,16 @@ async def rolling_update(
except:
raise

async def scale(self, replicas: int):
"""
Scale the number of replicas of a given Executor.
:param replicas: The number of replicas to scale to
"""
# potential exception will be raised to CompoundPod or Flow
await self.replica_set.scale(replicas)
self.args.replicas = replicas

@staticmethod
def _set_peas_args(
args: Namespace,
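A side note on `_scale_up` above: each new replica's configuration is derived by shallow-copying the first replica's `argparse.Namespace` and overriding the per-replica fields. A standalone sketch of that pattern (the field values are illustrative; in the diff, `port_ctrl` comes from `helper.random_port()`):

import copy
from argparse import Namespace

template = Namespace(name='executor/rep-0', replica_id=0, port_ctrl=52000)

# derive the args for replica 3: shallow-copy the template, then override
new_args = copy.copy(template)
new_args.name = new_args.name[:-1] + '3'  # -> 'executor/rep-3'
new_args.replica_id = 3
new_args.port_ctrl = 52003  # a fresh control port per replica
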
23 changes: 23 additions & 0 deletions jina/peapods/pods/compound.py
@@ -287,6 +287,29 @@ async def rolling_update(
if not task.done():
task.cancel()

async def scale(self, replicas: int):
"""
Scale the number of replicas of a given Executor.
:param replicas: The number of replicas to scale to
"""
tasks = []
try:
import asyncio

tasks = [
asyncio.create_task(shard.scale(replicas=replicas))
for shard in self.shards
]
for future in asyncio.as_completed(tasks):
_ = await future
except:
# TODO: Handle the failure of one of the shards. Scale all of them back to the original state? Cancelling would potentially be dangerous.
for task in tasks:
if not task.done():
task.cancel()
raise

@property
def _mermaid_str(self) -> List[str]:
"""String that will be used to represent the Pod graphically when `Flow.plot()` is invoked
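`CompoundPod.scale` fans the request out to all shards concurrently and cancels the stragglers if any shard fails. A self-contained sketch of that asyncio pattern, with `scale_shard` as a stand-in for `shard.scale`:

import asyncio

async def scale_shard(shard_id: int, replicas: int) -> int:
    # stand-in for shard.scale(replicas=replicas)
    await asyncio.sleep(0.1 * (shard_id + 1))
    return shard_id

async def scale_all_shards(replicas: int, num_shards: int = 3):
    tasks = [
        asyncio.create_task(scale_shard(i, replicas)) for i in range(num_shards)
    ]
    try:
        for future in asyncio.as_completed(tasks):
            await future  # re-raises the first shard failure
    except Exception:
        # cancel whatever is still running before propagating the failure
        for task in tasks:
            if not task.done():
                task.cancel()
        raise

asyncio.run(scale_all_shards(replicas=4))
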
84 changes: 84 additions & 0 deletions jina/peapods/pods/k8s.py
@@ -256,6 +256,63 @@ async def wait_restart_success(self, previous_uids: Iterable[str] = None):
fail_msg += f': {repr(exception_to_raise)}'
raise RuntimeFailToStart(fail_msg)

async def wait_scale_success(self, replicas: int):
scale_to = replicas
_timeout = self.common_args.timeout_ready
if _timeout <= 0:
_timeout = None
else:
_timeout /= 1e3  # timeout_ready is given in milliseconds; convert to seconds

import asyncio
from kubernetes import client

with JinaLogger(f'waiting_scale_for_{self.name}') as logger:
logger.info(
f'🏝️\n\t\tWaiting for "{self.name}" to be scaled from {self.num_replicas} replicas'
f' to {scale_to}.'
)
timeout_ns = 1000000000 * _timeout if _timeout else None
now = time.time_ns()
exception_to_raise = None
while timeout_ns is None or time.time_ns() - now < timeout_ns:
try:
api_response = kubernetes_client.K8sClients().apps_v1.read_namespaced_deployment(
name=self.dns_name, namespace=self.k8s_namespace
)
logger.debug(
f'\n\t\t Scaled replicas: {api_response.status.ready_replicas}.'
f' Replicas: {api_response.status.replicas}.'
f' Expected Replicas {scale_to}'
)
if (
api_response.status.ready_replicas is not None
and api_response.status.ready_replicas == scale_to
):
logger.success(
f' {self.name} has all its replicas scaled!!'
)
return
else:
scaled_replicas = api_response.status.ready_replicas or 0
if scaled_replicas < scale_to:
logger.debug(
f'\nNumber of replicas {scaled_replicas}, waiting for {scale_to - scaled_replicas} replicas to be scaled up.'
)
else:
logger.debug(
f'\nNumber of replicas {scaled_replicas}, waiting for {scaled_replicas - scale_to} replicas to be scaled down.'
)

await asyncio.sleep(1.0)
except client.ApiException as ex:
exception_to_raise = ex
break
fail_msg = f' Deployment {self.name} did not scale within a timeout of {self.common_args.timeout_ready}'
if exception_to_raise:
fail_msg += f': {repr(exception_to_raise)}'
raise RuntimeFailToStart(fail_msg)

def rolling_update(
self, dump_path: Optional[str] = None, *, uses_with: Optional[Dict] = None
):
@@ -271,6 +328,14 @@ def rolling_update(
self.deployment_args.dump_path = dump_path
self._restart_runtime()

def scale(self, replicas: int):
"""
Scale the number of replicas of a given Executor.
:param replicas: The number of replicas to scale to
"""
self._patch_namespaced_deployment_scale(replicas)

def start(self):
with JinaLogger(f'start_{self.name}') as logger:
logger.debug(f'\t\tDeploying "{self.name}"')
@@ -311,6 +376,13 @@ def _read_namespaced_deployment(self):
name=self.dns_name, namespace=self.k8s_namespace
)

def _patch_namespaced_deployment_scale(self, replicas: int):
kubernetes_client.K8sClients().apps_v1.patch_namespaced_deployment_scale(
self.dns_name,
namespace=self.k8s_namespace,
body={'spec': {'replicas': replicas}},
)

def get_pod_uids(self) -> List[str]:
"""Get the UIDs for all Pods in this deployment
@@ -473,6 +545,18 @@ async def rolling_update(
for deployment in self.k8s_deployments:
await deployment.wait_restart_success(old_uids[deployment.dns_name])

async def scale(self, replicas: int):
"""
Scale the number of replicas of a given Executor.
:param replicas: The number of replicas to scale to
"""
for deployment in self.k8s_deployments:
deployment.scale(replicas=replicas)
for deployment in self.k8s_deployments:
await deployment.wait_scale_success(replicas=replicas)
deployment.num_replicas = replicas

def start(self) -> 'K8sPod':
"""Deploy the kubernetes pods via k8s Deployment and k8s Service.
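Outside Jina, the same two-step flow can be reproduced with the official `kubernetes` Python client: patch the Deployment's `spec.replicas`, then poll `status.ready_replicas` until it converges. A sketch under the assumption of a configured kubeconfig; the deployment and namespace names are placeholders:

import time
from kubernetes import client, config

config.load_kube_config()  # or config.load_incluster_config() inside a cluster
apps = client.AppsV1Api()

def scale_and_wait(name: str, namespace: str, replicas: int, timeout: float = 60.0):
    # step 1: patch the desired replica count, as _patch_namespaced_deployment_scale does
    apps.patch_namespaced_deployment_scale(
        name, namespace=namespace, body={'spec': {'replicas': replicas}}
    )
    # step 2: poll readiness, as wait_scale_success does
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        status = apps.read_namespaced_deployment(name, namespace=namespace).status
        if (status.ready_replicas or 0) == replicas:
            return
        time.sleep(1.0)
    raise TimeoutError(f'{name} did not reach {replicas} ready replicas in {timeout}s')

scale_and_wait('my-deployment', 'my-namespace', replicas=3)
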
@@ -1,15 +1,16 @@
import pytest
import time
import os

from jina import Flow, Executor, Document, DocumentArray, requests
import pytest

from jina import Flow, Document, DocumentArray

cur_dir = os.path.dirname(os.path.abspath(__file__))

img_name = 'jina/replica-exec'


@pytest.fixture(scope='module')
@pytest.fixture(scope='function')
def docker_image_built():
import docker

@@ -26,7 +27,7 @@ def docker_image_built():
@pytest.mark.parametrize('replicas', [1, 3, 4])
def test_containerruntime_args(docker_image_built, shards, replicas):
f = Flow().add(
name='executor',
name='executor_container',
uses=f'docker://{img_name}',
replicas=replicas,
shards=shards,
@@ -51,8 +52,9 @@ def test_containerruntime_args(docker_image_built, shards, replicas):
for doc in r.docs:
assert doc.tags['shards'] == shards

assert shard_ids == set(range(shards))

if replicas > 1:
assert replica_ids == set(range(replicas))
else:
assert replica_ids == {-1.0}
assert shard_ids == set(range(shards))
2 changes: 1 addition & 1 deletion tests/integration/hub_usage/dummyhub/Dockerfile
@@ -5,4 +5,4 @@ COPY . /workspace/

WORKDIR /workspace

ENTRYPOINT ["jina", "pod", "--uses", "config.yml"]
ENTRYPOINT ["jina", "executor", "--uses", "config.yml"]
2 changes: 1 addition & 1 deletion tests/integration/hub_usage/dummyhub_abs/Dockerfile
@@ -5,4 +5,4 @@ COPY . /workspace/

WORKDIR /workspace

ENTRYPOINT ["jina", "pod", "--uses", "config.yml"]
ENTRYPOINT ["jina", "executor", "--uses", "config.yml"]
2 changes: 1 addition & 1 deletion tests/integration/hub_usage/dummyhub_pretrained/Dockerfile
@@ -5,4 +5,4 @@ COPY . /workspace/

WORKDIR /workspace

ENTRYPOINT ["jina", "pod", "--uses", "config.yml"]
ENTRYPOINT ["jina", "executor", "--uses", "config.yml"]
2 changes: 1 addition & 1 deletion tests/integration/hub_usage/dummyhub_slow/Dockerfile
@@ -5,4 +5,4 @@ COPY . /workspace/

WORKDIR /workspace

ENTRYPOINT ["jina", "pod", "--uses", "config.yml"]
ENTRYPOINT ["jina", "executor", "--uses", "config.yml"]
