From 77f7d662cf8f87df8eb66be3222458e80f2f481a Mon Sep 17 00:00:00 2001 From: Heiru Wu Date: Thu, 28 Sep 2023 18:41:50 +0800 Subject: [PATCH] fix(clients): fix metadata overwrite --- instill/clients/connector.py | 24 ++++++++-------- instill/clients/mgmt.py | 24 ++++++++-------- instill/clients/model.py | 56 ++++++++++++++++++------------------ instill/clients/pipeline.py | 22 +++++++------- 4 files changed, 63 insertions(+), 63 deletions(-) diff --git a/instill/clients/connector.py b/instill/clients/connector.py index 1276208..f14f2b6 100644 --- a/instill/clients/connector.py +++ b/instill/clients/connector.py @@ -20,21 +20,20 @@ class ConnectorClient(Client): def __init__(self, namespace: str) -> None: - self.hosts = defaultdict(dict) + self.hosts: defaultdict = defaultdict(dict) self.instance = "default" self.namespace = namespace - self.metadata: str = "" if global_config.hosts is not None: for instance, config in global_config.hosts.items(): if not config.secure: - self.metadata = ( + channel = grpc.insecure_channel(config.url) + self.hosts[instance]["metadata"] = ( ( "authorization", f"Bearer {config.token}", ), ) - channel = grpc.insecure_channel(config.url) else: ssl_creds = grpc.ssl_channel_credentials() call_creds = grpc.access_token_call_credentials(config.token) @@ -43,6 +42,7 @@ def __init__(self, namespace: str) -> None: target=config.url, credentials=creds, ) + self.hosts[instance]["metadata"] = "" self.hosts[instance]["token"] = config.token self.hosts[instance]["channel"] = channel self.hosts[instance][ @@ -107,7 +107,7 @@ def create_connector( request=connector_interface.CreateUserConnectorResourceRequest( connector_resource=connector, parent=self.namespace ), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) return resp.connector_resource @@ -120,7 +120,7 @@ def get_connector(self, name: str) -> connector_interface.ConnectorResource: request=connector_interface.GetUserConnectorResourceRequest( name=f"{self.namespace}/connector-resources/{name}" ), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) .connector_resource ) @@ -133,7 +133,7 @@ def test_connector(self, name: str) -> connector_interface.ConnectorResource.Sta request=connector_interface.TestUserConnectorResourceRequest( name=f"{self.namespace}/connector-resources/{name}" ), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) .state ) @@ -146,7 +146,7 @@ def execute_connector(self, name: str, inputs: list) -> list: request=connector_interface.ExecuteUserConnectorResourceRequest( name=f"{self.namespace}/connector-resources/{name}", inputs=inputs ), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) .outputs ) @@ -159,7 +159,7 @@ def watch_connector(self, name: str) -> connector_interface.ConnectorResource.St request=connector_interface.WatchUserConnectorResourceRequest( name=f"{self.namespace}/connector-resources/{name}" ), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) .state ) @@ -170,7 +170,7 @@ def delete_connector(self, name: str): request=connector_interface.DeleteUserConnectorResourceRequest( name=f"{self.namespace}/connector-resources/{name}" ), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) @grpc_handler @@ -180,12 +180,12 @@ def list_connectors(self, public=False) -> Tuple[list, str, int]: request=connector_interface.ListUserConnectorResourcesRequest( parent=self.namespace ), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) else: resp = self.hosts[self.instance]["client"].ListConnectorResources( request=connector_interface.ListConnectorResourcesRequest(), - metadata=(self.metadata,), + metadata=(self.hosts[self.instance]["metadata"],), ) return resp.connector_resources, resp.next_page_token, resp.total_size diff --git a/instill/clients/mgmt.py b/instill/clients/mgmt.py index 341ccd6..529b124 100644 --- a/instill/clients/mgmt.py +++ b/instill/clients/mgmt.py @@ -16,20 +16,19 @@ class MgmtClient(Client): def __init__(self) -> None: - self.hosts = defaultdict(dict) + self.hosts: defaultdict = defaultdict(dict) self.instance: str = "default" - self.metadata: str = "" if global_config.hosts is not None: for instance, config in global_config.hosts.items(): if not config.secure: - self.metadata = ( + channel = grpc.insecure_channel(config.url) + self.hosts[instance]["metadata"] = ( ( "authorization", f"Bearer {config.token}", ), ) - channel = grpc.insecure_channel(config.url) else: ssl_creds = grpc.ssl_channel_credentials() call_creds = grpc.access_token_call_credentials(config.token) @@ -38,6 +37,7 @@ def __init__(self) -> None: target=config.url, credentials=creds, ) + self.hosts[instance]["metadata"] = "" self.hosts[instance]["token"] = config.token self.hosts[instance]["channel"] = channel self.hosts[instance]["client"] = mgmt_service.MgmtPublicServiceStub( @@ -103,7 +103,7 @@ def login(self, username="admin", password="password") -> str: def get_token(self, name: str) -> mgmt_interface.ApiToken: response = self.hosts[self.instance]["client"].GetToken( request=mgmt_interface.GetTokenRequest(name=name), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) return response.token @@ -111,7 +111,7 @@ def get_token(self, name: str) -> mgmt_interface.ApiToken: def get_user(self) -> mgmt_interface.User: response = self.hosts[self.instance]["client"].QueryAuthenticatedUser( request=mgmt_interface.QueryAuthenticatedUserRequest(), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) return response.user @@ -121,7 +121,7 @@ def list_pipeline_trigger_records( ) -> metric_interface.ListPipelineTriggerRecordsResponse: return self.hosts[self.instance]["client"].ListPipelineTriggerRecords( request=metric_interface.ListPipelineTriggerChartRecordsRequest(), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) @grpc_handler @@ -130,7 +130,7 @@ def list_pipeline_trigger_table_records( ) -> metric_interface.ListPipelineTriggerTableRecordsRequest: return self.hosts[self.instance]["client"].ListPipelineTriggerRecords( request=metric_interface.ListPipelineTriggerTableRecordsResponse(), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) @grpc_handler @@ -139,7 +139,7 @@ def list_pipeline_trigger_chart_records( ) -> metric_interface.ListPipelineTriggerChartRecordsResponse: return self.hosts[self.instance]["client"].ListPipelineTriggerRecords( request=metric_interface.ListPipelineTriggerChartRecordsRequest(), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) @grpc_handler @@ -148,7 +148,7 @@ def list_connector_execute_records( ) -> metric_interface.ListConnectorExecuteRecordsResponse: return self.hosts[self.instance]["client"].ListPipelineTriggerRecords( request=metric_interface.ListConnectorExecuteRecordsRequest(), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) @grpc_handler @@ -157,7 +157,7 @@ def list_connector_execute_table_records( ) -> metric_interface.ListConnectorExecuteTableRecordsResponse: return self.hosts[self.instance]["client"].ListPipelineTriggerRecords( request=metric_interface.ListConnectorExecuteTableRecordsRequest(), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) @grpc_handler @@ -166,5 +166,5 @@ def list_connector_execute_chart_records( ) -> metric_interface.ListConnectorExecuteChartRecordsResponse: return self.hosts[self.instance]["client"].ListPipelineTriggerRecords( request=metric_interface.ListConnectorExecuteChartRecordsRequest(), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) diff --git a/instill/clients/model.py b/instill/clients/model.py index 453c6a9..bd386f9 100644 --- a/instill/clients/model.py +++ b/instill/clients/model.py @@ -19,21 +19,20 @@ class ModelClient(Client): def __init__(self, namespace: str) -> None: - self.hosts = defaultdict(dict) + self.hosts: defaultdict = defaultdict(dict) self.instance: str = "default" self.namespace: str = namespace - self.metadata: str = "" if global_config.hosts is not None: for instance, config in global_config.hosts.items(): if not config.secure: - self.metadata = ( + channel = grpc.insecure_channel(config.url) + self.hosts[instance]["metadata"] = ( ( "authorization", f"Bearer {config.token}", ), ) - channel = grpc.insecure_channel(config.url) else: ssl_creds = grpc.ssl_channel_credentials() call_creds = grpc.access_token_call_credentials(config.token) @@ -42,6 +41,7 @@ def __init__(self, namespace: str) -> None: target=config.url, credentials=creds, ) + self.hosts[instance]["metadata"] = "" self.hosts[instance]["token"] = config.token self.hosts[instance]["channel"] = channel self.hosts[instance]["client"] = model_service.ModelPublicServiceStub( @@ -99,7 +99,7 @@ def watch_model(self, model_name: str) -> model_interface.Model.State: request=model_interface.WatchUserModelRequest( name=f"{self.namespace}/models/{model_name}" ), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) .state ) @@ -123,7 +123,7 @@ def create_model_local( ) resp = self.hosts[self.instance]["client"].CreateUserModelBinaryFileUpload( request_iterator=iter([req]), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) while ( @@ -132,7 +132,7 @@ def create_model_local( request=model_interface.GetModelOperationRequest( name=resp.operation.name ), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) .operation.done is not True @@ -143,7 +143,7 @@ def create_model_local( request=model_interface.WatchUserModelRequest( name=f"{self.namespace}/models/{model_name}" ), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) while watch_resp.state == 0: time.sleep(1) @@ -151,7 +151,7 @@ def create_model_local( request=model_interface.WatchUserModelRequest( name=f"{self.namespace}/models/{model_name}" ), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) if watch_resp.state == 1: @@ -159,7 +159,7 @@ def create_model_local( self.hosts[self.instance]["client"] .GetUserModel( request=model_interface.GetUserModelRequest(name=model_name), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) .model ) @@ -181,7 +181,7 @@ def create_model( request=model_interface.CreateUserModelRequest( model=model, parent=self.namespace ), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) while ( @@ -190,7 +190,7 @@ def create_model( request=model_interface.GetModelOperationRequest( name=resp.operation.name ), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) .operation.done is not True @@ -205,7 +205,7 @@ def create_model( request=model_interface.WatchUserModelRequest( name=f"{self.namespace}/models/{model.id}" ), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) while watch_resp.state == 0: time.sleep(1) @@ -213,7 +213,7 @@ def create_model( request=model_interface.WatchUserModelRequest( name=f"{self.namespace}/models/{model.id}" ), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) if watch_resp.state == 1: @@ -223,7 +223,7 @@ def create_model( request=model_interface.GetUserModelRequest( name=f"{self.namespace}/models/{model.id}" ), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) .model ) @@ -236,14 +236,14 @@ def deploy_model(self, model_name: str) -> model_interface.Model.State: request=model_interface.DeployUserModelRequest( name=f"{self.namespace}/models/{model_name}" ), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) watch_resp = self.hosts[self.instance]["client"].WatchUserModel( request=model_interface.WatchUserModelRequest( name=f"{self.namespace}/models/{model_name}" ), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) while watch_resp.state not in (2, 3): time.sleep(1) @@ -251,7 +251,7 @@ def deploy_model(self, model_name: str) -> model_interface.Model.State: request=model_interface.WatchUserModelRequest( name=f"{self.namespace}/models/{model_name}" ), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) return watch_resp.state @@ -262,14 +262,14 @@ def undeploy_model(self, model_name: str) -> model_interface.Model.State: request=model_interface.UndeployUserModelRequest( name=f"{self.namespace}/models/{model_name}" ), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) watch_resp = self.hosts[self.instance]["client"].WatchUserModel( request=model_interface.WatchUserModelRequest( name=f"{self.namespace}/models/{model_name}" ), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) while watch_resp.state not in (1, 3): time.sleep(1) @@ -277,7 +277,7 @@ def undeploy_model(self, model_name: str) -> model_interface.Model.State: request=model_interface.WatchUserModelRequest( name=f"{self.namespace}/models/{model_name}" ), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) return watch_resp.state @@ -288,7 +288,7 @@ def trigger_model(self, model_name: str, task_inputs: list) -> list: request=model_interface.TriggerUserModelRequest( name=f"{self.namespace}/models/{model_name}", task_inputs=task_inputs ), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) return resp.task_outputs @@ -298,7 +298,7 @@ def delete_model(self, model_name: str): request=model_interface.DeleteUserModelRequest( name=f"{self.namespace}/models/{model_name}" ), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) @grpc_handler @@ -309,7 +309,7 @@ def get_model(self, model_name: str) -> model_interface.Model: request=model_interface.GetUserModelRequest( name=f"{self.namespace}/models/{model_name}" ), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) .model ) @@ -322,7 +322,7 @@ def get_model_by_uid(self, model_uid: str) -> model_interface.Model: request=model_interface.LookUpModelRequest( permalink=f"models/{model_uid}" ), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) .model ) @@ -335,7 +335,7 @@ def get_model_card(self, model_name: str) -> model_interface.Model: request=model_interface.GetUserModelCardRequest( name=f"{self.namespace}/models/{model_name}" ), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) .readme ) @@ -345,12 +345,12 @@ def list_models(self, public=False) -> Tuple[list, str, int]: if not public: resp = self.hosts[self.instance]["client"].ListUserModels( request=model_interface.ListUserModelsRequest(parent=self.namespace), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) else: resp = self.hosts[self.instance]["client"].ListModels( request=model_interface.ListModelsRequest(), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) return resp.models, resp.next_page_token, resp.total_size diff --git a/instill/clients/pipeline.py b/instill/clients/pipeline.py index cd2cd19..1558605 100644 --- a/instill/clients/pipeline.py +++ b/instill/clients/pipeline.py @@ -20,21 +20,20 @@ class PipelineClient(Client): def __init__(self, namespace: str) -> None: - self.hosts = defaultdict(dict) + self.hosts: defaultdict = defaultdict(dict) self.instance: str = "default" self.namespace: str = namespace - self.metadata: str = "" if global_config.hosts is not None: for instance, config in global_config.hosts.items(): if not config.secure: - self.metadata = ( + channel = grpc.insecure_channel(config.url) + self.hosts[instance]["metadata"] = ( ( "authorization", f"Bearer {config.token}", ), ) - channel = grpc.insecure_channel(config.url) else: ssl_creds = grpc.ssl_channel_credentials() call_creds = grpc.access_token_call_credentials(config.token) @@ -43,6 +42,7 @@ def __init__(self, namespace: str) -> None: target=config.url, credentials=creds, ) + self.hosts[instance]["metadata"] = "" self.hosts[instance]["token"] = config.token self.hosts[instance]["channel"] = channel self.hosts[instance][ @@ -106,7 +106,7 @@ def create_pipeline( request=pipeline_interface.CreateUserPipelineRequest( pipeline=pipeline, parent=self.namespace ), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) return resp.pipeline @@ -118,7 +118,7 @@ def get_pipeline(self, name: str) -> pipeline_interface.Pipeline: request=pipeline_interface.GetUserPipelineRequest( name=f"{self.namespace}/pipelines/{name}" ), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) .pipeline ) @@ -131,7 +131,7 @@ def validate_pipeline(self, name: str) -> pipeline_interface.Pipeline: request=pipeline_interface.ValidateUserPipelineRequest( name=f"{self.namespace}/pipelines/{name}" ), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) .pipeline ) @@ -144,7 +144,7 @@ def trigger_pipeline( request=pipeline_interface.TriggerUserPipelineRequest( name=f"{self.namespace}/pipelines/{name}", inputs=inputs ), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) return resp.outputs, resp.metadata @@ -154,7 +154,7 @@ def delete_pipeline(self, name: str): request=pipeline_interface.DeleteUserPipelineRequest( name=f"{self.namespace}/pipelines/{name}" ), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) @grpc_handler @@ -164,12 +164,12 @@ def list_pipelines(self, public=False) -> Tuple[list, str, int]: request=pipeline_interface.ListUserPipelinesRequest( parent=self.namespace ), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) else: resp = self.hosts[self.instance]["client"].ListPipelines( request=pipeline_interface.ListPipelinesRequest(), - metadata=self.metadata, + metadata=self.hosts[self.instance]["metadata"], ) return resp.pipelines, resp.next_page_token, resp.total_size