Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions instill/clients/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from abc import ABC, abstractmethod
from typing import Union

import google.protobuf.message
import grpc


class Client(ABC):
Expand Down Expand Up @@ -49,3 +53,27 @@ def readiness(self):
@abstractmethod
def is_serving(self):
raise NotImplementedError


class RequestFactory:
def __init__(
self,
method: Union[grpc.UnaryUnaryMultiCallable, grpc.StreamUnaryMultiCallable],
request: google.protobuf.message.Message,
metadata,
) -> None:
self.method = method
self.request = request
self.metadata = metadata

def send_sync(self):
return self.method(request=self.request, metadata=self.metadata)

def send_stream(self):
return self.method(
request_iterator=iter([self.request]),
metadata=self.metadata,
)

async def send_async(self):
return await self.method(request=self.request, metadata=self.metadata)
70 changes: 25 additions & 45 deletions instill/clients/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,49 +5,23 @@
from instill.utils.error_handler import NotServingException
from instill.utils.logger import Logger

_mgmt_client = None
_pipeline_client = None
_model_client = None
_client = None


def _get_mgmt_client() -> MgmtClient:
global _mgmt_client

if _mgmt_client is None:
_mgmt_client = MgmtClient()

return _mgmt_client


def _get_pipeline_client() -> PipelineClient:
global _pipeline_client

if _pipeline_client is None:
_pipeline_client = PipelineClient(namespace=_get_mgmt_client().get_user().name)

return _pipeline_client


def _get_model_client() -> ModelClient:
global _model_client

if _model_client is None:
_model_client = ModelClient(namespace=_get_mgmt_client().get_user().name)

return _model_client


class InstillClient:
def __init__(self) -> None:
self.mgmt_service = _get_mgmt_client()
def __init__(self, async_enabled: bool = False) -> None:
self.mgmt_service = MgmtClient(async_enabled=async_enabled)
if not self.mgmt_service.is_serving():
Logger.w("Instill Core is required")
raise NotServingException
self.pipeline_service = _get_pipeline_client()
self.pipeline_service = PipelineClient(
namespace=self.mgmt_service.get_user().user.name,
async_enabled=async_enabled,
)
if not self.pipeline_service.is_serving():
Logger.w("Instill VDP is not serving, VDP functionalities will not work")
self.model_service = _get_model_client()
self.model_service = ModelClient(
namespace=self.mgmt_service.get_user().user.name,
async_enabled=async_enabled,
)
if not self.model_service.is_serving():
Logger.w(
"Instill Model is not serving, Model functionalities will not work"
Expand All @@ -61,19 +35,25 @@ def set_instance(self, instance: str):
def close(self):
if self.mgmt_service.is_serving():
for host in self.mgmt_service.hosts.values():
host["channel"].close()
host.channel.close()
if self.pipeline_service.is_serving():
for host in self.pipeline_service.hosts.values():
host["channel"].close()
host.channel.close()
if self.model_service.is_serving():
for host in self.model_service.hosts.values():
host["channel"].close()
host.channel.close()

async def async_close(self):
if self.mgmt_service.is_serving():
for host in self.mgmt_service.hosts.values():
await host.async_channel.close()
if self.pipeline_service.is_serving():
for host in self.pipeline_service.hosts.values():
await host.async_channel.close()
if self.model_service.is_serving():
for host in self.model_service.hosts.values():
await host.async_channel.close()

def get_client() -> InstillClient:
global _client

if _client is None:
_client = InstillClient()

return _client
def get_client(async_enabled: bool = False) -> InstillClient:
return InstillClient(async_enabled=async_enabled)
45 changes: 45 additions & 0 deletions instill/clients/instance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from typing import Union

import grpc

import instill.protogen.core.mgmt.v1alpha.mgmt_public_service_pb2_grpc as mgmt_service
import instill.protogen.model.model.v1alpha.model_public_service_pb2_grpc as model_service
import instill.protogen.vdp.pipeline.v1alpha.pipeline_public_service_pb2_grpc as pipeline_service


class InstillInstance:
def __init__(self, stub, url: str, token: str, secure: bool, async_enabled: bool):
self.url: str = url
self.token: str = token
self.async_enabled: bool = async_enabled
self.metadata: Union[str, tuple] = ""
if not secure:
channel = grpc.insecure_channel(url)
self.metadata = (
(
"authorization",
f"Bearer {token}",
),
)
if async_enabled:
async_channel = grpc.aio.insecure_channel(url)
else:
ssl_creds = grpc.ssl_channel_credentials()
call_creds = grpc.access_token_call_credentials(token)
creds = grpc.composite_channel_credentials(ssl_creds, call_creds)
channel = grpc.secure_channel(target=url, credentials=creds)
if async_enabled:
async_channel = grpc.aio.secure_channel(target=url, credentials=creds)
self.channel: grpc.Channel = channel
self.client: Union[
model_service.ModelPublicServiceStub,
pipeline_service.PipelinePublicServiceStub,
mgmt_service.MgmtPublicServiceStub,
] = stub(channel)
if async_enabled:
self.async_channel: grpc.Channel = async_channel
self.async_client: Union[
model_service.ModelPublicServiceStub,
pipeline_service.PipelinePublicServiceStub,
mgmt_service.MgmtPublicServiceStub,
] = stub(async_channel)
Loading