Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,14 @@ def get_grpc_device_server_location(
)
return service_location

compute_nodes = discovery_client.enumerate_compute_nodes()
remote_compute_nodes = [node for node in compute_nodes if not node.is_local]
# Use remote node URL as deployment target if only one remote node is found.
# If more than one remote node exists, use empty string for deployment target.
first_remote_node_url = remote_compute_nodes[0].url if len(remote_compute_nodes) == 1 else ""
service_location = discovery_client.resolve_service(
provided_interface=provided_interface,
deployment_target=first_remote_node_url,
service_class=SERVICE_CLASS,
)
_logger.debug(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""Public API for accessing the NI Discovery Service."""

from ni_measurement_plugin_sdk_service.discovery._client import DiscoveryClient
from ni_measurement_plugin_sdk_service.discovery._types import ServiceLocation
from ni_measurement_plugin_sdk_service.discovery._types import (
ComputeNodeDescriptor,
ServiceLocation,
)

__all__ = ["DiscoveryClient", "ServiceLocation"]
__all__ = ["DiscoveryClient", "ServiceLocation", "ComputeNodeDescriptor"]
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@
from ni_measurement_plugin_sdk_service.discovery._support import (
_get_discovery_service_address,
)
from ni_measurement_plugin_sdk_service.discovery._types import ServiceLocation
from ni_measurement_plugin_sdk_service.discovery._types import (
ComputeNodeDescriptor,
ServiceLocation,
)
from ni_measurement_plugin_sdk_service.grpc.channelpool import GrpcChannelPool
from ni_measurement_plugin_sdk_service.measurement.info import (
MeasurementInfo,
Expand Down Expand Up @@ -304,3 +307,20 @@ def enumerate_services(self, provided_interface: str) -> Sequence[ServiceInfo]:
response = self._get_stub().EnumerateServices(request)

return [ServiceInfo._from_grpc(service) for service in response.available_services]

def enumerate_compute_nodes(self) -> Sequence[ComputeNodeDescriptor]:
"""Enumerates all the compute nodes registered with the discovery service.

Returns:
The list of information describing the compute nodes.
"""
request = discovery_service_pb2.EnumerateComputeNodesRequest()

try:
response = self._get_stub().EnumerateComputeNodes(request)
except grpc.RpcError as e:
if e.code() == grpc.StatusCode.UNIMPLEMENTED:
return []
raise

return [ComputeNodeDescriptor._from_grpc(node) for node in response.compute_nodes]
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,22 @@ def _from_grpc(cls, other: discovery_service_pb2.ServiceLocation) -> ServiceLoca
insecure_port=other.insecure_port,
ssl_authenticated_port=other.ssl_authenticated_port,
)


class ComputeNodeDescriptor(typing.NamedTuple):
"""Represents a compute node."""

url: str
"""The resolvable name (URL) of the compute node."""

is_local: bool
"""Indicates whether the compute node is local node."""

@classmethod
def _from_grpc(
cls, other: discovery_service_pb2.ComputeNodeDescriptor
) -> ComputeNodeDescriptor:
return ComputeNodeDescriptor(
url=other.url,
is_local=other.is_local,
)
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,16 @@ def _get_stub(self) -> pin_map_service_pb2_grpc.PinMapServiceStub:
grpc_channel_pool=self._grpc_channel_pool
)
if self._stub is None:
compute_nodes = self._discovery_client.enumerate_compute_nodes()
remote_compute_nodes = [node for node in compute_nodes if not node.is_local]
# Use remote node URL as deployment target if only one remote node is found.
# If more than one remote node exists, use empty string for deployment target.
first_remote_node_url = (
remote_compute_nodes[0].url if len(remote_compute_nodes) == 1 else ""
)
service_location = self._discovery_client.resolve_service(
provided_interface=GRPC_SERVICE_INTERFACE_NAME,
deployment_target=first_remote_node_url,
service_class=GRPC_SERVICE_CLASS,
)
channel = self._grpc_channel_pool.get_channel(service_location.insecure_address)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,16 @@ def _get_stub(self) -> session_management_service_pb2_grpc.SessionManagementServ
grpc_channel_pool=self._grpc_channel_pool
)
if self._stub is None:
compute_nodes = self._discovery_client.enumerate_compute_nodes()
remote_compute_nodes = [node for node in compute_nodes if not node.is_local]
# Use remote node URL as deployment target if only one remote node is found.
# If more than one remote node exists, use empty string for deployment target.
first_remote_node_url = (
remote_compute_nodes[0].url if len(remote_compute_nodes) == 1 else ""
)
service_location = self._discovery_client.resolve_service(
provided_interface=GRPC_SERVICE_INTERFACE_NAME,
deployment_target=first_remote_node_url,
service_class=GRPC_SERVICE_CLASS,
)
channel = self._grpc_channel_pool.get_channel(service_location.insecure_address)
Expand Down
2 changes: 2 additions & 0 deletions packages/service/tests/unit/_drivers/test_grpcdevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def test___default_configuration___get_grpc_device_server_location___resolves_se

discovery_client.resolve_service.assert_called_with(
provided_interface=fake_driver.GRPC_SERVICE_INTERFACE_NAME,
deployment_target="",
service_class=SERVICE_CLASS,
)
assert service_location == ServiceLocation("localhost", "1234", "")
Expand Down Expand Up @@ -136,6 +137,7 @@ def test___default_configuration___get_insecure_grpc_device_server_channel___res

discovery_client.resolve_service.assert_called_with(
provided_interface=fake_driver.GRPC_SERVICE_INTERFACE_NAME,
deployment_target="",
service_class=SERVICE_CLASS,
)
grpc_channel_pool.get_channel.assert_called_with("localhost:1234")
Expand Down
74 changes: 73 additions & 1 deletion packages/service/tests/unit/test_discovery_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
SERVICE_PROGRAMMINGLANGUAGE_KEY,
)
from ni_measurement_plugin_sdk_service._internal.stubs.ni.measurementlink.discovery.v1.discovery_service_pb2 import (
ComputeNodeDescriptor as GrpcComputeNodeDescriptor,
EnumerateComputeNodesResponse,
EnumerateServicesRequest,
EnumerateServicesResponse,
RegisterServiceRequest,
Expand All @@ -34,7 +36,11 @@
from ni_measurement_plugin_sdk_service._internal.stubs.ni.measurementlink.discovery.v1.discovery_service_pb2_grpc import (
DiscoveryServiceStub,
)
from ni_measurement_plugin_sdk_service.discovery import DiscoveryClient, ServiceLocation
from ni_measurement_plugin_sdk_service.discovery import (
ComputeNodeDescriptor,
DiscoveryClient,
ServiceLocation,
)
from ni_measurement_plugin_sdk_service.discovery._support import (
_get_discovery_service_address,
_open_key_file,
Expand Down Expand Up @@ -459,6 +465,71 @@ def test___no_registered_measurements___resolve_service_with_information___raise
assert exc_info.value.code() == grpc.StatusCode.NOT_FOUND


def test___registered_compute_node___enumerate_compute_nodes___returns_node(
discovery_client: DiscoveryClient, discovery_service_stub: Mock
):
expected_node = ComputeNodeDescriptor(url="http://remotehost:42000", is_local=False)
discovery_service_stub.EnumerateComputeNodes.return_value = EnumerateComputeNodesResponse(
compute_nodes=[
GrpcComputeNodeDescriptor(url=expected_node.url, is_local=expected_node.is_local),
]
)

compute_nodes = discovery_client.enumerate_compute_nodes()

discovery_service_stub.EnumerateComputeNodes.assert_called_once()
assert compute_nodes[0] == expected_node


def test___multiple_registered_compute_nodes___enumerate_compute_nodes___returns_all_nodes(
discovery_client: DiscoveryClient, discovery_service_stub: Mock
):
expected_nodes = [
ComputeNodeDescriptor(url="http://localhost:42000", is_local=True),
ComputeNodeDescriptor(url="http://remotehost:42001", is_local=False),
]
discovery_service_stub.EnumerateComputeNodes.return_value = EnumerateComputeNodesResponse(
compute_nodes=[
GrpcComputeNodeDescriptor(
url=expected_nodes[0].url, is_local=expected_nodes[0].is_local
),
GrpcComputeNodeDescriptor(
url=expected_nodes[1].url, is_local=expected_nodes[1].is_local
),
]
)

compute_nodes = discovery_client.enumerate_compute_nodes()

discovery_service_stub.EnumerateComputeNodes.assert_called_once()
assert compute_nodes == expected_nodes


def test___no_registered_compute_nodes___enumerate_compute_nodes___returns_empty_list(
discovery_client: DiscoveryClient, discovery_service_stub: Mock
):
discovery_service_stub.EnumerateComputeNodes.return_value = EnumerateComputeNodesResponse()

compute_nodes = discovery_client.enumerate_compute_nodes()

discovery_service_stub.EnumerateComputeNodes.assert_called_once()
assert compute_nodes == []


def test___enumerate_compute_nodes___grpc_error___raises_rpc_error(
discovery_client: DiscoveryClient, discovery_service_stub: Mock
):
discovery_service_stub.EnumerateComputeNodes.side_effect = FakeRpcError(
grpc.StatusCode.UNAVAILABLE, details="Test service unavailable"
)

with pytest.raises(grpc.RpcError) as exc_info:
discovery_client.enumerate_compute_nodes()

discovery_service_stub.EnumerateComputeNodes.assert_called_once()
assert exc_info.value.code() == grpc.StatusCode.UNAVAILABLE


@pytest.fixture(scope="module")
def subprocess_popen_kwargs() -> dict[str, Any]:
kwargs: dict[str, Any] = {}
Expand Down Expand Up @@ -488,6 +559,7 @@ def discovery_service_stub(mocker: MockerFixture) -> Mock:
stub.EnumerateServices = mocker.create_autospec(grpc.UnaryUnaryMultiCallable)
stub.ResolveService = mocker.create_autospec(grpc.UnaryUnaryMultiCallable)
stub.ResolveServiceWithInformation = mocker.create_autospec(grpc.UnaryUnaryMultiCallable)
stub.EnumerateComputeNodes = mocker.create_autospec(grpc.UnaryUnaryMultiCallable)
return stub


Expand Down