From 395577dbd7144c7dcbe140d71d1fe23b6e8d3807 Mon Sep 17 00:00:00 2001 From: Nathan Van Gheem Date: Mon, 17 Jul 2023 17:14:31 -0500 Subject: [PATCH] fix b/w compat chitchat implementation (#1109) --- .../common/cluster/discovery/chitchat.py | 40 +++++++++++++++++++ .../common/cluster/discovery/test_chitchat.py | 8 +++- 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/nucliadb/nucliadb/common/cluster/discovery/chitchat.py b/nucliadb/nucliadb/common/cluster/discovery/chitchat.py index 3f8a8be3e2..39c414d18c 100644 --- a/nucliadb/nucliadb/common/cluster/discovery/chitchat.py +++ b/nucliadb/nucliadb/common/cluster/discovery/chitchat.py @@ -20,10 +20,12 @@ from __future__ import annotations import asyncio +from enum import Enum from typing import Optional import pydantic from fastapi import APIRouter, FastAPI, Response +from nucliadb_protos.writer_pb2 import Member from uvicorn.config import Config # type: ignore from uvicorn.server import Server # type: ignore @@ -39,10 +41,47 @@ api_router = APIRouter() +class MemberType(str, Enum): + IO = "Io" + SEARCH = "Search" + INGEST = "Ingest" + TRAIN = "Train" + UNKNOWN = "Unknown" + + @staticmethod + def from_pb(node_type: Member.Type.ValueType): + if node_type == Member.Type.IO: + return MemberType.IO + elif node_type == Member.Type.SEARCH: + return MemberType.SEARCH + elif node_type == Member.Type.INGEST: + return MemberType.INGEST + elif node_type == Member.Type.TRAIN: + return MemberType.TRAIN + elif node_type == Member.Type.UNKNOWN: + return MemberType.UNKNOWN + else: + raise ValueError(f"incompatible node type '{node_type}'") + + def to_pb(self) -> Member.Type.ValueType: + if self == MemberType.IO: + return Member.Type.IO + elif self == MemberType.SEARCH: + return Member.Type.SEARCH + elif self == MemberType.INGEST: + return Member.Type.INGEST + elif self == MemberType.TRAIN: + return Member.Type.TRAIN + else: + return Member.Type.UNKNOWN + + class ClusterMember(pydantic.BaseModel): node_id: str = pydantic.Field(alias="id") listen_addr: str = pydantic.Field(alias="address") shard_count: Optional[int] + type: MemberType = MemberType.UNKNOWN + is_self: bool = False class Config: allow_population_by_field_name = True @@ -59,6 +98,7 @@ async def api_update_members(members: list[ClusterMember]) -> Response: shard_count=member.shard_count or 0, ) for member in members + if not member.is_self and member.type == MemberType.IO ] ) return Response(status_code=204) diff --git a/nucliadb/nucliadb/tests/integration/common/cluster/discovery/test_chitchat.py b/nucliadb/nucliadb/tests/integration/common/cluster/discovery/test_chitchat.py index 2135a542ee..cc19470062 100644 --- a/nucliadb/nucliadb/tests/integration/common/cluster/discovery/test_chitchat.py +++ b/nucliadb/nucliadb/tests/integration/common/cluster/discovery/test_chitchat.py @@ -47,7 +47,13 @@ def make_client_fixture(): async def test_chitchat_monitor(chitchat_monitor_client): INDEX_NODES.clear() async with chitchat_monitor_client() as client: - member = dict(node_id=f"node", listen_addr=f"10.0.0.0", shard_count=20) + member = dict( + node_id=f"node", + listen_addr=f"10.0.0.0", + shard_count=20, + type="Io", + is_self=False, + ) response = await client.patch("/members", json=[member]) assert response.status_code == 204 assert len(INDEX_NODES) == 1