Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,21 @@

from __future__ import annotations

import logging
import pathlib
import threading

import grpc
import ni.measurementlink.pinmap.v1.pin_map_service_pb2 as pin_map_service_pb2
import ni.measurementlink.pinmap.v1.pin_map_service_pb2_grpc as pin_map_service_pb2_grpc
from ni.measurementlink.discovery.v1.client import DiscoveryClient as DiscoveryClient
from ni.measurementlink.discovery.v1.client import DiscoveryClient
from ni.measurementlink.pinmap.v1 import pin_map_service_pb2_grpc
from ni_grpc_extensions.channelpool import GrpcChannelPool

_logger = logging.getLogger(__name__)
from ni.measurementlink.pinmap.v1.client._client_base import GrpcServiceClientBase

GRPC_SERVICE_INTERFACE_NAME = "ni.measurementlink.pinmap.v1.PinMapService"
GRPC_SERVICE_CLASS = "ni.measurementlink.pinmap.v1.PinMapService"

class PinMapClient(GrpcServiceClientBase[pin_map_service_pb2_grpc.PinMapServiceStub]):
"""Client for accessing the NI Pin Map Service via gRPC."""

class PinMapClient:
"""Client for accessing the NI Pin Map Service."""
__slots__ = ()

def __init__(
self,
Expand All @@ -37,41 +34,14 @@ def __init__(

grpc_channel_pool: An optional gRPC channel pool (recommended).
"""
self._initialization_lock = threading.Lock()
self._discovery_client = discovery_client
self._grpc_channel_pool = grpc_channel_pool
self._stub: pin_map_service_pb2_grpc.PinMapServiceStub | None = None

if grpc_channel is not None:
self._stub = pin_map_service_pb2_grpc.PinMapServiceStub(grpc_channel)

def _get_stub(self) -> pin_map_service_pb2_grpc.PinMapServiceStub:
if self._stub is None:
with self._initialization_lock:
if self._grpc_channel_pool is None:
_logger.debug("Creating unshared GrpcChannelPool.")
self._grpc_channel_pool = GrpcChannelPool()
if self._discovery_client is None:
_logger.debug("Creating unshared DiscoveryClient.")
self._discovery_client = DiscoveryClient(
grpc_channel_pool=self._grpc_channel_pool
)
if self._stub is None:
compute_nodes = self._discovery_client.enumerate_compute_nodes()
remote_compute_nodes = [node for node in compute_nodes if not node.is_local]
# Use remote node URL as deployment target if only one remote node is found.
# If more than one remote node exists, use empty string for deployment target.
first_remote_node_url = (
remote_compute_nodes[0].url if len(remote_compute_nodes) == 1 else ""
)
service_location = self._discovery_client.resolve_service(
provided_interface=GRPC_SERVICE_INTERFACE_NAME,
deployment_target=first_remote_node_url,
service_class=GRPC_SERVICE_CLASS,
)
channel = self._grpc_channel_pool.get_channel(service_location.insecure_address)
self._stub = pin_map_service_pb2_grpc.PinMapServiceStub(channel)
return self._stub
super().__init__(
discovery_client=discovery_client,
grpc_channel=grpc_channel,
grpc_channel_pool=grpc_channel_pool,
service_interface_name="ni.measurementlink.pinmap.v1.PinMapService",
service_class="ni.measurementlink.pinmap.v1.PinMapService",
stub_class=pin_map_service_pb2_grpc.PinMapServiceStub,
)

def update_pin_map(self, pin_map_path: str | pathlib.Path) -> str:
"""Update registered pin map contents.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from __future__ import annotations

import logging
import threading
from typing import Generic, Protocol, TypeVar

import grpc
from ni.measurementlink.discovery.v1.client import DiscoveryClient
from ni_grpc_extensions.channelpool import GrpcChannelPool

_logger = logging.getLogger(__name__)


class StubProtocol(Protocol):
"""Protocol for gRPC stub classes."""

def __init__(self, channel: grpc.Channel) -> None:
"""Initialize the gRPC client."""


TStub = TypeVar("TStub", bound=StubProtocol)


class GrpcServiceClientBase(Generic[TStub]):
"""Base class for NI gRPC service clients."""

__slots__ = (
"_initialization_lock",
"_discovery_client",
"_grpc_channel_pool",
"_stub",
"_service_interface_name",
"_service_class",
"_stub_class",
)

_initialization_lock: threading.Lock
_discovery_client: DiscoveryClient | None
_grpc_channel_pool: GrpcChannelPool | None
_stub: TStub | None
_service_interface_name: str
_service_class: str
_stub_class: type[TStub]

def __init__(
self,
service_interface_name: str,
service_class: str,
stub_class: type[TStub],
*,
discovery_client: DiscoveryClient | None = None,
grpc_channel: grpc.Channel | None = None,
grpc_channel_pool: GrpcChannelPool | None = None,
) -> None:
"""Initialize the gRPC client.

Args:
service_interface_name: The fully qualified name of the service interface.
service_class: The name of the service class.
stub_class: The gRPC stub class for the service.
discovery_client: An optional discovery client (recommended).
grpc_channel: An optional pin map gRPC channel.
grpc_channel_pool: An optional gRPC channel pool (recommended).
"""
self._initialization_lock = threading.Lock()
self._discovery_client = discovery_client
self._grpc_channel_pool = grpc_channel_pool
self._stub = stub_class(grpc_channel) if grpc_channel is not None else None
self._service_interface_name = service_interface_name
self._service_class = service_class
self._stub_class = stub_class

def _get_stub(self) -> TStub:
if self._stub is None:
with self._initialization_lock:
if self._grpc_channel_pool is None:
_logger.debug("Creating unshared GrpcChannelPool.")
self._grpc_channel_pool = GrpcChannelPool()

if self._discovery_client is None:
_logger.debug("Creating unshared DiscoveryClient.")
self._discovery_client = DiscoveryClient(
grpc_channel_pool=self._grpc_channel_pool
)

if self._stub is None:
compute_nodes = self._discovery_client.enumerate_compute_nodes()
remote_nodes = [node for node in compute_nodes if not node.is_local]
target_url = remote_nodes[0].url if len(remote_nodes) == 1 else ""

service_location = self._discovery_client.resolve_service(
provided_interface=self._service_interface_name,
deployment_target=target_url,
service_class=self._service_class,
)

channel = self._grpc_channel_pool.get_channel(service_location.insecure_address)
self._stub = self._stub_class(channel)

return self._stub
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from __future__ import annotations

import logging
import threading
import warnings
from collections.abc import Iterable, Mapping

Expand All @@ -14,6 +13,9 @@
from ni.measurementlink.discovery.v1.client import DiscoveryClient
from ni_grpc_extensions.channelpool import GrpcChannelPool

from ni.measurementlink.sessionmanagement.v1.client._client_base import (
GrpcServiceClientBase,
)
from ni.measurementlink.sessionmanagement.v1.client._constants import (
GRPC_SERVICE_CLASS,
GRPC_SERVICE_INTERFACE_NAME,
Expand All @@ -32,8 +34,12 @@
_logger = logging.getLogger(__name__)


class SessionManagementClient:
"""Client for accessing the measurement plug-in session management service."""
class SessionManagementClient(
GrpcServiceClientBase[session_management_service_pb2_grpc.SessionManagementServiceStub]
):
"""Client for accessing the NI Session Management Service via gRPC."""

__slots__ = ()

def __init__(
self,
Expand All @@ -42,54 +48,15 @@ def __init__(
grpc_channel: grpc.Channel | None = None,
grpc_channel_pool: GrpcChannelPool | None = None,
) -> None:
"""Initialize session management client.

Args:
discovery_client: An optional discovery client (recommended).

grpc_channel: An optional session management gRPC channel.

grpc_channel_pool: An optional gRPC channel pool (recommended).
"""
self._initialization_lock = threading.Lock()
self._discovery_client = discovery_client
self._grpc_channel_pool = grpc_channel_pool
self._stub: session_management_service_pb2_grpc.SessionManagementServiceStub | None = None

if grpc_channel is not None:
self._stub = session_management_service_pb2_grpc.SessionManagementServiceStub(
grpc_channel
)

def _get_stub(self) -> session_management_service_pb2_grpc.SessionManagementServiceStub:
if self._stub is None:
with self._initialization_lock:
if self._grpc_channel_pool is None:
_logger.debug("Creating unshared GrpcChannelPool.")
self._grpc_channel_pool = GrpcChannelPool()
if self._discovery_client is None:
_logger.debug("Creating unshared DiscoveryClient.")
self._discovery_client = DiscoveryClient(
grpc_channel_pool=self._grpc_channel_pool
)
if self._stub is None:
compute_nodes = self._discovery_client.enumerate_compute_nodes()
remote_compute_nodes = [node for node in compute_nodes if not node.is_local]
# Use remote node URL as deployment target if only one remote node is found.
# If more than one remote node exists, use empty string for deployment target.
first_remote_node_url = (
remote_compute_nodes[0].url if len(remote_compute_nodes) == 1 else ""
)
service_location = self._discovery_client.resolve_service(
provided_interface=GRPC_SERVICE_INTERFACE_NAME,
deployment_target=first_remote_node_url,
service_class=GRPC_SERVICE_CLASS,
)
channel = self._grpc_channel_pool.get_channel(service_location.insecure_address)
self._stub = session_management_service_pb2_grpc.SessionManagementServiceStub(
channel
)
return self._stub
"""Initialize a SessionManagementClient instance."""
super().__init__(
discovery_client=discovery_client,
grpc_channel=grpc_channel,
grpc_channel_pool=grpc_channel_pool,
service_interface_name=GRPC_SERVICE_INTERFACE_NAME,
service_class=GRPC_SERVICE_CLASS,
stub_class=session_management_service_pb2_grpc.SessionManagementServiceStub,
)

def reserve_session(
self,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from __future__ import annotations

import logging
import threading
from typing import Generic, Protocol, TypeVar

import grpc
from ni.measurementlink.discovery.v1.client import DiscoveryClient
from ni_grpc_extensions.channelpool import GrpcChannelPool

_logger = logging.getLogger(__name__)


class StubProtocol(Protocol):
"""Protocol for gRPC stub classes."""

def __init__(self, channel: grpc.Channel) -> None:
"""Initialize the gRPC client."""


TStub = TypeVar("TStub", bound=StubProtocol)


class GrpcServiceClientBase(Generic[TStub]):
"""Base class for NI gRPC service clients."""

__slots__ = (
"_initialization_lock",
"_discovery_client",
"_grpc_channel_pool",
"_stub",
"_service_interface_name",
"_service_class",
"_stub_class",
)

_initialization_lock: threading.Lock
_discovery_client: DiscoveryClient | None
_grpc_channel_pool: GrpcChannelPool | None
_stub: TStub | None
_service_interface_name: str
_service_class: str
_stub_class: type[TStub]

def __init__(
self,
service_interface_name: str,
service_class: str,
stub_class: type[TStub],
*,
discovery_client: DiscoveryClient | None = None,
grpc_channel: grpc.Channel | None = None,
grpc_channel_pool: GrpcChannelPool | None = None,
) -> None:
"""Initialize the gRPC client.

Args:
service_interface_name: The fully qualified name of the service interface.
service_class: The name of the service class.
stub_class: The gRPC stub class for the service.
discovery_client: An optional discovery client (recommended).
grpc_channel: An optional pin map gRPC channel.
grpc_channel_pool: An optional gRPC channel pool (recommended).
"""
self._initialization_lock = threading.Lock()
self._discovery_client = discovery_client
self._grpc_channel_pool = grpc_channel_pool
self._stub = stub_class(grpc_channel) if grpc_channel is not None else None
self._service_interface_name = service_interface_name
self._service_class = service_class
self._stub_class = stub_class

def _get_stub(self) -> TStub:
if self._stub is None:
with self._initialization_lock:
if self._grpc_channel_pool is None:
_logger.debug("Creating unshared GrpcChannelPool.")
self._grpc_channel_pool = GrpcChannelPool()

if self._discovery_client is None:
_logger.debug("Creating unshared DiscoveryClient.")
self._discovery_client = DiscoveryClient(
grpc_channel_pool=self._grpc_channel_pool
)

if self._stub is None:
compute_nodes = self._discovery_client.enumerate_compute_nodes()
remote_nodes = [node for node in compute_nodes if not node.is_local]
target_url = remote_nodes[0].url if len(remote_nodes) == 1 else ""

service_location = self._discovery_client.resolve_service(
provided_interface=self._service_interface_name,
deployment_target=target_url,
service_class=self._service_class,
)

channel = self._grpc_channel_pool.get_channel(service_location.insecure_address)
self._stub = self._stub_class(channel)

return self._stub
Loading
Loading