From 9e806389a799441ef0bcd4c80f7180bc2c106a6b Mon Sep 17 00:00:00 2001 From: Ivan Zubenko Date: Mon, 17 Jul 2023 21:58:27 +0300 Subject: [PATCH] add update kube token task (#822) --- platform_monitoring/kube_client.py | 120 ++++++++++++----------------- tests/integration/test_kube.py | 99 +++++++++++++++++------- 2 files changed, 121 insertions(+), 98 deletions(-) diff --git a/platform_monitoring/kube_client.py b/platform_monitoring/kube_client.py index 7826efb9..4301de7c 100644 --- a/platform_monitoring/kube_client.py +++ b/platform_monitoring/kube_client.py @@ -5,7 +5,7 @@ import re import ssl from collections.abc import AsyncIterator, Sequence -from contextlib import asynccontextmanager +from contextlib import asynccontextmanager, suppress from dataclasses import dataclass, field from datetime import datetime from pathlib import Path @@ -315,6 +315,7 @@ def __init__( auth_cert_key_path: Optional[str] = None, token: Optional[str] = None, token_path: Optional[str] = None, + token_update_interval_s: int = 300, conn_timeout_s: int = 300, read_timeout_s: int = 100, conn_pool_size: int = 100, @@ -333,6 +334,7 @@ def __init__( self._auth_cert_key_path = auth_cert_key_path self._token = token self._token_path = token_path + self._token_update_interval_s = token_update_interval_s self._conn_timeout_s = conn_timeout_s self._read_timeout_s = read_timeout_s @@ -344,6 +346,7 @@ def __init__( self._trace_configs = trace_configs self._client: Optional[aiohttp.ClientSession] = None + self._token_updater_task: Optional[asyncio.Task[None]] = None @property def _is_ssl(self) -> bool: @@ -363,43 +366,35 @@ def _create_ssl_context(self) -> Optional[ssl.SSLContext]: return ssl_context async def init(self) -> None: - self._client = await self.create_http_client() - - async def init_if_needed(self) -> None: - if not self._client or self._client.closed: - await self.init() - - async def create_http_client( - self, *, force_close: bool = False - ) -> aiohttp.ClientSession: connector = aiohttp.TCPConnector( - limit=self._conn_pool_size, - ssl=self._create_ssl_context(), - force_close=force_close, + limit=self._conn_pool_size, ssl=self._create_ssl_context() ) - if self._auth_type == KubeClientAuthType.TOKEN: - token = self._token - if not token: - assert self._token_path is not None - token = Path(self._token_path).read_text() - headers = {"Authorization": "Bearer " + token} - else: - headers = {} + if self._token_path: + self._token = Path(self._token_path).read_text() + self._token_updater_task = asyncio.create_task(self._start_token_updater()) timeout = aiohttp.ClientTimeout( connect=self._conn_timeout_s, total=self._read_timeout_s ) - return aiohttp.ClientSession( + self._client = aiohttp.ClientSession( connector=connector, timeout=timeout, - headers=headers, trace_configs=self._trace_configs, ) - async def _recreate_http_client(self) -> None: - logger.warning("Reloading http K8s client.") - await self.close() - self._token = None - await self.init() + async def _start_token_updater(self) -> None: + if not self._token_path: + return + while True: + try: + token = Path(self._token_path).read_text() + if token != self._token: + self._token = token + logger.info("Kube token was refreshed") + except asyncio.CancelledError: + raise + except Exception as exc: + logger.exception("Failed to update kube token: %s", exc) + await asyncio.sleep(self._token_update_interval_s) @property def namespace(self) -> str: @@ -409,6 +404,11 @@ async def close(self) -> None: if self._client: await self._client.close() self._client = None + if self._token_updater_task: + self._token_updater_task.cancel() + with suppress(asyncio.CancelledError): + await self._token_updater_task + self._token_updater_task = None async def __aenter__(self) -> "KubeClient": await self.init() @@ -460,24 +460,22 @@ def _generate_pod_log_url(self, pod_name: str, container_name: str) -> str: url = f"{url}/log?container={pod_name}&follow=true" return url + def _create_headers( + self, headers: Optional[dict[str, Any]] = None + ) -> dict[str, Any]: + headers = dict(headers) if headers else {} + if self._auth_type == KubeClientAuthType.TOKEN and self._token: + headers["Authorization"] = "Bearer " + self._token + return headers + async def _request(self, *args: Any, **kwargs: Any) -> dict[str, Any]: - await self.init_if_needed() + headers = self._create_headers(kwargs.pop("headers", None)) assert self._client, "client is not initialized" - doing_retry = kwargs.pop("doing_retry", False) - - async with self._client.request(*args, **kwargs) as response: - try: - await self._check_response_status(response) - payload = await response.json() - logger.debug("k8s response payload: %s", payload) - except KubeClientUnauthorized: - if doing_retry: - raise - await self._recreate_http_client() - kwargs["doing_retry"] = True - payload = await self._request(*args, **kwargs) - - return payload + async with self._client.request(*args, headers=headers, **kwargs) as response: + await self._check_response_status(response) + payload = await response.json() + logger.debug("k8s response payload: %s", payload) + return payload async def get_raw_pod(self, pod_name: str) -> dict[str, Any]: url = self._generate_pod_url(pod_name) @@ -534,18 +532,6 @@ async def wait_pod_is_not_waiting( return status await asyncio.sleep(interval_s) - def _get_node_proxy_url(self, host: str, port: int) -> URL: - return URL(self._generate_node_proxy_url(host, port)) - - @asynccontextmanager - async def get_node_proxy_client( - self, host: str, port: int - ) -> AsyncIterator[ProxyClient]: - assert self._client - yield ProxyClient( - url=self._get_node_proxy_url(host, port), session=self._client - ) - async def get_pod_container_stats( self, node_name: str, pod_name: str, container_name: str ) -> Optional["PodContainerStats"]: @@ -570,26 +556,20 @@ async def get_pod_container_gpu_stats( node_name: str, pod_name: str, container_name: str, - doing_retry: bool = False, ) -> Optional["PodContainerGPUStats"]: url = self._generate_node_gpu_metrics_url(node_name) if not url: return None try: - await self.init_if_needed() assert self._client - async with self._client.get(url, raise_for_status=True) as resp: + async with self._client.get( + url, headers=self._create_headers(), raise_for_status=True + ) as resp: text = await resp.text() gpu_counters = GPUCounters.parse(text) return gpu_counters.get_pod_container_stats( self._namespace, pod_name, container_name ) - except aiohttp.ClientResponseError as e: - if e.status == 401 and not doing_retry: - await self._recreate_http_client() - return await self.get_pod_container_gpu_stats( - node_name, pod_name, container_name, doing_retry=True, - ) except aiohttp.ClientError as e: logger.exception(e) return None @@ -624,9 +604,9 @@ async def create_pod_container_logs_stream( client_timeout = aiohttp.ClientTimeout( connect=conn_timeout_s, sock_read=read_timeout_s ) - await self.init_if_needed() - async with self._client.get( # type: ignore - url, timeout=client_timeout + assert self._client + async with self._client.get( + url, headers=self._create_headers(), timeout=client_timeout ) as response: await self._check_response_status(response, job_id=pod_name) yield response.content @@ -678,9 +658,7 @@ def _assert_resource_kind( elif kind != expected_kind: raise ValueError(f"unknown kind: {kind}") - def _raise_for_status( - self, payload: dict[str, Any], job_id: Optional[str] - ) -> None: + def _raise_for_status(self, payload: dict[str, Any], job_id: Optional[str]) -> None: if payload["code"] == 400: if "ContainerCreating" in payload["message"]: raise JobNotFoundException(f"job '{job_id}' was not created yet") diff --git a/tests/integration/test_kube.py b/tests/integration/test_kube.py index ff20b810..a4227f57 100644 --- a/tests/integration/test_kube.py +++ b/tests/integration/test_kube.py @@ -2,29 +2,32 @@ import asyncio import logging +import os import re +import tempfile import uuid -from pathlib import Path -from collections.abc import AsyncIterator, Callable, Coroutine +from collections.abc import AsyncIterator, Callable, Coroutine, Iterator from contextlib import AbstractAsyncContextManager from datetime import datetime, timedelta, timezone +from pathlib import Path from typing import Any, Union from unittest import mock from uuid import uuid4 +import aiohttp +import aiohttp.web import pytest from aiobotocore.client import AioBaseClient from aioelasticsearch import Elasticsearch from aiohttp import web from async_timeout import timeout -from yarl import URL from platform_monitoring.config import KubeConfig from platform_monitoring.kube_client import ( JobNotFoundException, KubeClient, - KubeClientException, KubeClientAuthType, + KubeClientException, PodContainerStats, PodPhase, ) @@ -39,8 +42,8 @@ ) from platform_monitoring.utils import parse_date +from .conftest import ApiAddress, create_local_app_server from .conftest_kube import MyKubeClient, MyPodDescriptor -from tests.integration.conftest import ApiAddress, create_local_app_server logger = logging.getLogger(__name__) @@ -72,15 +75,16 @@ async def _stats_summary(request: web.Request) -> web.Response: async def _gpu_metrics(request: web.Request) -> web.Response: return web.Response(content_type="text/plain") - def _unauthorized_gpu_metrics( - ) -> Callable[[web.Request], Coroutine[Any, Any, web.Response]]: - + def _unauthorized_gpu_metrics() -> ( + Callable[[web.Request], Coroutine[Any, Any, web.Response]] + ): async def _inner(request: web.Request) -> web.Response: auth_header = request.headers.get("Authorization", "") if auth_header.split(" ")[1] == "authorized": return web.Response(content_type="text/plain") else: return web.Response(status=401) + return _inner def _create_app() -> web.Application: @@ -131,6 +135,65 @@ def s3_log_service( ) +class TestKubeClientTokenUpdater: + @pytest.fixture + async def kube_app(self) -> aiohttp.web.Application: + async def _get_nodes(request: aiohttp.web.Request) -> aiohttp.web.Response: + auth = request.headers["Authorization"] + token = auth.split()[-1] + app["token"]["value"] = token + return aiohttp.web.json_response({"kind": "NodeList", "items": []}) + + app = aiohttp.web.Application() + app["token"] = {"value": ""} + app.router.add_routes([aiohttp.web.get("/api/v1/nodes", _get_nodes)]) + return app + + @pytest.fixture + async def kube_server( + self, kube_app: aiohttp.web.Application, unused_tcp_port_factory: Any + ) -> AsyncIterator[str]: + async with create_local_app_server( + kube_app, port=unused_tcp_port_factory() + ) as address: + yield f"http://{address.host}:{address.port}" + + @pytest.fixture + def kube_token_path(self) -> Iterator[str]: + _, path = tempfile.mkstemp() + Path(path).write_text("token-1") + yield path + os.remove(path) + + @pytest.fixture + async def kube_client( + self, kube_server: str, kube_token_path: str + ) -> AsyncIterator[KubeClient]: + async with KubeClient( + base_url=kube_server, + namespace="default", + auth_type=KubeClientAuthType.TOKEN, + token_path=kube_token_path, + token_update_interval_s=1, + ) as client: + yield client + + async def test_token_periodically_updated( + self, + kube_app: aiohttp.web.Application, + kube_client: KubeClient, + kube_token_path: str, + ) -> None: + await kube_client.get_nodes() + assert kube_app["token"]["value"] == "token-1" + + Path(kube_token_path).write_text("token-2") + await asyncio.sleep(2) + + await kube_client.get_nodes() + assert kube_app["token"]["value"] == "token-2" + + class TestKubeClient: async def test_wait_pod_is_running_not_found( self, kube_client: MyKubeClient @@ -473,24 +536,6 @@ async def test_create_log_stream( payload = await stream.read() assert payload == b"" - async def test_get_node_proxy_client( - self, kube_config: KubeConfig, kube_client: MyKubeClient - ) -> None: - node_list = await kube_client.get_node_list() - node_name = node_list["items"][0]["metadata"]["name"] - async with kube_client.get_node_proxy_client( - node_name, KubeConfig.kubelet_node_port - ) as client: - assert client.url == URL( - kube_config.endpoint_url - + f"/api/v1/nodes/{node_name}:{KubeConfig.kubelet_node_port}/proxy" - ) - - async with client.session.get(URL(f"{client.url}/stats/summary")) as resp: - assert resp.status == 200, await resp.text() - payload = await resp.json() - assert "node" in payload - async def test_get_nodes(self, kube_client: MyKubeClient) -> None: nodes = await kube_client.get_nodes() assert nodes @@ -534,7 +579,7 @@ async def test_get_pods( [ ["authorized", True], ["badtoken", False], - ] + ], ) async def test_get_pod_container_gpu_stats_handles_unauthorized( self,