Skip to content

Commit

Permalink
refactor: abstract the construction of different pods (#2346)
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM committed Apr 26, 2021
1 parent 9cb2330 commit f4d0893
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 27 deletions.
72 changes: 62 additions & 10 deletions cli/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,26 @@


def pod(args: 'Namespace'):
"""Start a Pod"""
from jina.peapods import Pod
"""
Start a Pod
:param args: arguments coming from the CLI.
"""
from jina.peapods.pods.factory import PodFactory

try:
with Pod(args) as p:
with PodFactory.build_pod(args) as p:
p.join()
except KeyboardInterrupt:
pass


def pea(args: 'Namespace'):
"""Start a Pea"""
"""
Start a Pea
:param args: arguments coming from the CLI.
"""
from jina.peapods import Pea

try:
Expand All @@ -28,31 +36,53 @@ def pea(args: 'Namespace'):


def gateway(args: 'Namespace'):
"""Start a Gateway Pod"""
"""
Start a Gateway Pod
:param args: arguments coming from the CLI.
"""
pod(args)


def check(args: 'Namespace'):
"""Check jina config, settings, imports, network etc"""
"""
Check jina config, settings, imports, network etc
:param args: arguments coming from the CLI.
"""
from jina.checker import ImportChecker

ImportChecker(args)


def ping(args: 'Namespace'):
"""
Check the connectivity of a Pea
:param args: arguments coming from the CLI.
"""
from jina.checker import NetworkChecker

NetworkChecker(args)


def client(args: 'Namespace'):
"""Start a client connects to the gateway"""
"""
Start a client connects to the gateway
:param args: arguments coming from the CLI.
"""
from jina.clients import Client

Client(args)


def export_api(args: 'Namespace'):
"""
Export the API
:param args: arguments coming from the CLI.
"""
import json
from .export import api_to_dict
from jina.jaml import JAML
Expand Down Expand Up @@ -86,12 +116,22 @@ def export_api(args: 'Namespace'):


def hello_world(args: 'Namespace'):
"""
Run the fashion hello world example
:param args: arguments coming from the CLI.
"""
from jina.helloworld.fashion import hello_world

hello_world(args)


def hello(args: 'Namespace'):
"""
Run any of the hello world examples
:param args: arguments coming from the CLI.
"""
if args.hello == 'fashion':
from jina.helloworld.fashion import hello_world
elif args.hello == 'chatbot':
Expand All @@ -105,7 +145,11 @@ def hello(args: 'Namespace'):


def flow(args: 'Namespace'):
"""Start a Flow from a YAML file or a docker image"""
"""
Start a Flow from a YAML file or a docker image
:param args: arguments coming from the CLI.
"""
from jina.flow import Flow

if args.uses:
Expand All @@ -119,14 +163,22 @@ def flow(args: 'Namespace'):


def optimizer(args: 'Namespace'):
"""Start an optimization from a YAML file"""
"""
Start an optimization from a YAML file
:param args: arguments coming from the CLI.
"""
from jina.optimizers import run_optimizer_cli

run_optimizer_cli(args)


def hub(args: 'Namespace'):
"""Start a hub builder for build, push, pull"""
"""
Start a hub builder for build, push, pull
:param args: arguments coming from the CLI.
"""
from jina.docker.hubio import HubIO

getattr(HubIO(args), args.hub)()
7 changes: 2 additions & 5 deletions jina/flow/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

from ..peapods import Pod
from ..peapods.pods.compoundpod import CompoundPod
from ..peapods.pods.factory import PodFactory


class FlowType(type(ExitStack), type(JAMLCompatible)):
Expand Down Expand Up @@ -306,11 +307,7 @@ def add(
parser = set_gateway_parser()

args = ArgNamespace.kwargs2namespace(kwargs, parser)

if args.replicas == 1:
op_flow._pod_nodes[pod_name] = Pod(args, needs=needs)
else:
op_flow._pod_nodes[pod_name] = CompoundPod(args, needs=needs)
op_flow._pod_nodes[pod_name] = PodFactory.build_pod(args, needs)
op_flow.last_pod = pod_name

return op_flow
Expand Down
16 changes: 9 additions & 7 deletions jina/peapods/pods/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
__copyright__ = "Copyright (c) 2020 Jina AI Limited. All rights reserved."
__license__ = "Apache-2.0"

import argparse

import copy
from abc import abstractmethod
from argparse import Namespace
Expand All @@ -26,7 +26,9 @@ class BasePod(ExitStack):
They can be also run in their own containers on remote machines.
"""

def __init__(self, args: Union['argparse.Namespace', Dict], needs: Set[str] = None):
def __init__(
self, args: Union['Namespace', Dict], needs: Optional[Set[str]] = None
):
super().__init__()
self.peas = [] # type: List['BasePea']
self.args = args
Expand Down Expand Up @@ -125,10 +127,6 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
super().__exit__(exc_type, exc_val, exc_tb)
self.join()

@staticmethod
def _set_after_to_pass(args):
raise NotImplemented

@staticmethod
def _copy_to_head_args(
args: Namespace, polling_type: PollingType, as_router: bool = True
Expand Down Expand Up @@ -256,13 +254,15 @@ def head_args(self):
.. # noqa: DAR201
"""
...

@abstractmethod
def tail_args(self):
"""Get the arguments for the `tail` of this BasePod.
.. # noqa: DAR201
"""
...


class Pod(BasePod):
Expand All @@ -272,7 +272,9 @@ class Pod(BasePod):
:param needs: pod names of preceding pods, the output of these pods are going into the input of this pod
"""

def __init__(self, args: Union['argparse.Namespace', Dict], needs: Set[str] = None):
def __init__(
self, args: Union['Namespace', Dict], needs: Optional[Set[str]] = None
):
super().__init__(args, needs)
if isinstance(args, Dict):
# This is used when a Pod is created in a remote context, where peas & their connections are already given.
Expand Down
14 changes: 9 additions & 5 deletions jina/peapods/pods/compoundpod.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
__copyright__ = "Copyright (c) 2020 Jina AI Limited. All rights reserved."
__license__ = "Apache-2.0"

import argparse
import copy
from argparse import Namespace
from itertools import cycle
Expand All @@ -16,13 +15,18 @@


class CompoundPod(BasePod):
"""A CompoundPod is a immutable set of pods, which run in parallel. They share the same input and output socket.
Internally, the peas of the pods can run with the process/thread backend. They can be also run in their own containers.
:param args: pod arguments parsed from the CLI
"""A CompoundPod is a immutable set of pods, which run in parallel.
A CompoundPod is an abstraction using a composable pattern to abstract the usage of parallel Pods that act as replicas.
CompoundPod will make sure to add a `HeadPea` and a `TailPea` to serve as routing/merging pattern for the different Pod replicas
:param args: pod arguments parsed from the CLI. These arguments will be used for each of the replicas
:param needs: pod names of preceding pods, the output of these pods are going into the input of this pod
"""

def __init__(self, args: Union['argparse.Namespace', Dict], needs: Set[str] = None):
def __init__(
self, args: Union['Namespace', Dict], needs: Optional[Set[str]] = None
):
super().__init__(args, needs)
self.replica_list = [] # type: List['Pod']
if isinstance(args, Dict):
Expand Down
28 changes: 28 additions & 0 deletions jina/peapods/pods/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Union, Optional, Dict, Set
from argparse import Namespace

from .. import Pod
from .. import BasePod
from .compoundpod import CompoundPod


class PodFactory:
"""
A PodFactory is a factory class, abstracting the Pod creation
"""

@staticmethod
def build_pod(
args: Union['Namespace', Dict], needs: Optional[Set[str]] = None
) -> BasePod:
"""Build an implementation of a `BasePod` interface
:param args: pod arguments parsed from the CLI.
:param needs: pod names of preceding pods
:return: the created BasePod
"""
if args.replicas > 1:
return CompoundPod(args, needs=needs)
else:
return Pod(args, needs=needs)
12 changes: 12 additions & 0 deletions tests/unit/peapods/pods/test_pod_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from jina.peapods.pods.factory import PodFactory
from jina.peapods.pods import Pod
from jina.peapods.pods.compoundpod import CompoundPod
from jina.parsers import set_pod_parser


def test_pod_factory_pod():
args_no_replicas = set_pod_parser().parse_args(['--replicas', '1'])
assert isinstance(PodFactory.build_pod(args_no_replicas), Pod)

args_replicas = set_pod_parser().parse_args(['--replicas', '2'])
assert isinstance(PodFactory.build_pod(args_replicas), CompoundPod)

0 comments on commit f4d0893

Please sign in to comment.