Skip to content

Commit

Permalink
add update kube token task (#822)
Browse files Browse the repository at this point in the history
  • Loading branch information
zubenkoivan committed Jul 17, 2023
1 parent 8a53d67 commit 9e80638
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 98 deletions.
120 changes: 49 additions & 71 deletions platform_monitoring/kube_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import re
import ssl
from collections.abc import AsyncIterator, Sequence
from contextlib import asynccontextmanager
from contextlib import asynccontextmanager, suppress
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
Expand Down Expand Up @@ -315,6 +315,7 @@ def __init__(
auth_cert_key_path: Optional[str] = None,
token: Optional[str] = None,
token_path: Optional[str] = None,
token_update_interval_s: int = 300,
conn_timeout_s: int = 300,
read_timeout_s: int = 100,
conn_pool_size: int = 100,
Expand All @@ -333,6 +334,7 @@ def __init__(
self._auth_cert_key_path = auth_cert_key_path
self._token = token
self._token_path = token_path
self._token_update_interval_s = token_update_interval_s

self._conn_timeout_s = conn_timeout_s
self._read_timeout_s = read_timeout_s
Expand All @@ -344,6 +346,7 @@ def __init__(
self._trace_configs = trace_configs

self._client: Optional[aiohttp.ClientSession] = None
self._token_updater_task: Optional[asyncio.Task[None]] = None

@property
def _is_ssl(self) -> bool:
Expand All @@ -363,43 +366,35 @@ def _create_ssl_context(self) -> Optional[ssl.SSLContext]:
return ssl_context

async def init(self) -> None:
self._client = await self.create_http_client()

async def init_if_needed(self) -> None:
if not self._client or self._client.closed:
await self.init()

async def create_http_client(
self, *, force_close: bool = False
) -> aiohttp.ClientSession:
connector = aiohttp.TCPConnector(
limit=self._conn_pool_size,
ssl=self._create_ssl_context(),
force_close=force_close,
limit=self._conn_pool_size, ssl=self._create_ssl_context()
)
if self._auth_type == KubeClientAuthType.TOKEN:
token = self._token
if not token:
assert self._token_path is not None
token = Path(self._token_path).read_text()
headers = {"Authorization": "Bearer " + token}
else:
headers = {}
if self._token_path:
self._token = Path(self._token_path).read_text()
self._token_updater_task = asyncio.create_task(self._start_token_updater())
timeout = aiohttp.ClientTimeout(
connect=self._conn_timeout_s, total=self._read_timeout_s
)
return aiohttp.ClientSession(
self._client = aiohttp.ClientSession(
connector=connector,
timeout=timeout,
headers=headers,
trace_configs=self._trace_configs,
)

async def _recreate_http_client(self) -> None:
logger.warning("Reloading http K8s client.")
await self.close()
self._token = None
await self.init()
async def _start_token_updater(self) -> None:
if not self._token_path:
return
while True:
try:
token = Path(self._token_path).read_text()
if token != self._token:
self._token = token
logger.info("Kube token was refreshed")
except asyncio.CancelledError:
raise
except Exception as exc:
logger.exception("Failed to update kube token: %s", exc)
await asyncio.sleep(self._token_update_interval_s)

@property
def namespace(self) -> str:
Expand All @@ -409,6 +404,11 @@ async def close(self) -> None:
if self._client:
await self._client.close()
self._client = None
if self._token_updater_task:
self._token_updater_task.cancel()
with suppress(asyncio.CancelledError):
await self._token_updater_task
self._token_updater_task = None

async def __aenter__(self) -> "KubeClient":
await self.init()
Expand Down Expand Up @@ -460,24 +460,22 @@ def _generate_pod_log_url(self, pod_name: str, container_name: str) -> str:
url = f"{url}/log?container={pod_name}&follow=true"
return url

def _create_headers(
self, headers: Optional[dict[str, Any]] = None
) -> dict[str, Any]:
headers = dict(headers) if headers else {}
if self._auth_type == KubeClientAuthType.TOKEN and self._token:
headers["Authorization"] = "Bearer " + self._token
return headers

async def _request(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
await self.init_if_needed()
headers = self._create_headers(kwargs.pop("headers", None))
assert self._client, "client is not initialized"
doing_retry = kwargs.pop("doing_retry", False)

async with self._client.request(*args, **kwargs) as response:
try:
await self._check_response_status(response)
payload = await response.json()
logger.debug("k8s response payload: %s", payload)
except KubeClientUnauthorized:
if doing_retry:
raise
await self._recreate_http_client()
kwargs["doing_retry"] = True
payload = await self._request(*args, **kwargs)

return payload
async with self._client.request(*args, headers=headers, **kwargs) as response:
await self._check_response_status(response)
payload = await response.json()
logger.debug("k8s response payload: %s", payload)
return payload

async def get_raw_pod(self, pod_name: str) -> dict[str, Any]:
url = self._generate_pod_url(pod_name)
Expand Down Expand Up @@ -534,18 +532,6 @@ async def wait_pod_is_not_waiting(
return status
await asyncio.sleep(interval_s)

def _get_node_proxy_url(self, host: str, port: int) -> URL:
return URL(self._generate_node_proxy_url(host, port))

@asynccontextmanager
async def get_node_proxy_client(
self, host: str, port: int
) -> AsyncIterator[ProxyClient]:
assert self._client
yield ProxyClient(
url=self._get_node_proxy_url(host, port), session=self._client
)

async def get_pod_container_stats(
self, node_name: str, pod_name: str, container_name: str
) -> Optional["PodContainerStats"]:
Expand All @@ -570,26 +556,20 @@ async def get_pod_container_gpu_stats(
node_name: str,
pod_name: str,
container_name: str,
doing_retry: bool = False,
) -> Optional["PodContainerGPUStats"]:
url = self._generate_node_gpu_metrics_url(node_name)
if not url:
return None
try:
await self.init_if_needed()
assert self._client
async with self._client.get(url, raise_for_status=True) as resp:
async with self._client.get(
url, headers=self._create_headers(), raise_for_status=True
) as resp:
text = await resp.text()
gpu_counters = GPUCounters.parse(text)
return gpu_counters.get_pod_container_stats(
self._namespace, pod_name, container_name
)
except aiohttp.ClientResponseError as e:
if e.status == 401 and not doing_retry:
await self._recreate_http_client()
return await self.get_pod_container_gpu_stats(
node_name, pod_name, container_name, doing_retry=True,
)
except aiohttp.ClientError as e:
logger.exception(e)
return None
Expand Down Expand Up @@ -624,9 +604,9 @@ async def create_pod_container_logs_stream(
client_timeout = aiohttp.ClientTimeout(
connect=conn_timeout_s, sock_read=read_timeout_s
)
await self.init_if_needed()
async with self._client.get( # type: ignore
url, timeout=client_timeout
assert self._client
async with self._client.get(
url, headers=self._create_headers(), timeout=client_timeout
) as response:
await self._check_response_status(response, job_id=pod_name)
yield response.content
Expand Down Expand Up @@ -678,9 +658,7 @@ def _assert_resource_kind(
elif kind != expected_kind:
raise ValueError(f"unknown kind: {kind}")

def _raise_for_status(
self, payload: dict[str, Any], job_id: Optional[str]
) -> None:
def _raise_for_status(self, payload: dict[str, Any], job_id: Optional[str]) -> None:
if payload["code"] == 400:
if "ContainerCreating" in payload["message"]:
raise JobNotFoundException(f"job '{job_id}' was not created yet")
Expand Down
99 changes: 72 additions & 27 deletions tests/integration/test_kube.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,32 @@

import asyncio
import logging
import os
import re
import tempfile
import uuid
from pathlib import Path
from collections.abc import AsyncIterator, Callable, Coroutine
from collections.abc import AsyncIterator, Callable, Coroutine, Iterator
from contextlib import AbstractAsyncContextManager
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any, Union
from unittest import mock
from uuid import uuid4

import aiohttp
import aiohttp.web
import pytest
from aiobotocore.client import AioBaseClient
from aioelasticsearch import Elasticsearch
from aiohttp import web
from async_timeout import timeout
from yarl import URL

from platform_monitoring.config import KubeConfig
from platform_monitoring.kube_client import (
JobNotFoundException,
KubeClient,
KubeClientException,
KubeClientAuthType,
KubeClientException,
PodContainerStats,
PodPhase,
)
Expand All @@ -39,8 +42,8 @@
)
from platform_monitoring.utils import parse_date

from .conftest import ApiAddress, create_local_app_server
from .conftest_kube import MyKubeClient, MyPodDescriptor
from tests.integration.conftest import ApiAddress, create_local_app_server

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -72,15 +75,16 @@ async def _stats_summary(request: web.Request) -> web.Response:
async def _gpu_metrics(request: web.Request) -> web.Response:
return web.Response(content_type="text/plain")

def _unauthorized_gpu_metrics(
) -> Callable[[web.Request], Coroutine[Any, Any, web.Response]]:

def _unauthorized_gpu_metrics() -> (
Callable[[web.Request], Coroutine[Any, Any, web.Response]]
):
async def _inner(request: web.Request) -> web.Response:
auth_header = request.headers.get("Authorization", "")
if auth_header.split(" ")[1] == "authorized":
return web.Response(content_type="text/plain")
else:
return web.Response(status=401)

return _inner

def _create_app() -> web.Application:
Expand Down Expand Up @@ -131,6 +135,65 @@ def s3_log_service(
)


class TestKubeClientTokenUpdater:
@pytest.fixture
async def kube_app(self) -> aiohttp.web.Application:
async def _get_nodes(request: aiohttp.web.Request) -> aiohttp.web.Response:
auth = request.headers["Authorization"]
token = auth.split()[-1]
app["token"]["value"] = token
return aiohttp.web.json_response({"kind": "NodeList", "items": []})

app = aiohttp.web.Application()
app["token"] = {"value": ""}
app.router.add_routes([aiohttp.web.get("/api/v1/nodes", _get_nodes)])
return app

@pytest.fixture
async def kube_server(
self, kube_app: aiohttp.web.Application, unused_tcp_port_factory: Any
) -> AsyncIterator[str]:
async with create_local_app_server(
kube_app, port=unused_tcp_port_factory()
) as address:
yield f"http://{address.host}:{address.port}"

@pytest.fixture
def kube_token_path(self) -> Iterator[str]:
_, path = tempfile.mkstemp()
Path(path).write_text("token-1")
yield path
os.remove(path)

@pytest.fixture
async def kube_client(
self, kube_server: str, kube_token_path: str
) -> AsyncIterator[KubeClient]:
async with KubeClient(
base_url=kube_server,
namespace="default",
auth_type=KubeClientAuthType.TOKEN,
token_path=kube_token_path,
token_update_interval_s=1,
) as client:
yield client

async def test_token_periodically_updated(
self,
kube_app: aiohttp.web.Application,
kube_client: KubeClient,
kube_token_path: str,
) -> None:
await kube_client.get_nodes()
assert kube_app["token"]["value"] == "token-1"

Path(kube_token_path).write_text("token-2")
await asyncio.sleep(2)

await kube_client.get_nodes()
assert kube_app["token"]["value"] == "token-2"


class TestKubeClient:
async def test_wait_pod_is_running_not_found(
self, kube_client: MyKubeClient
Expand Down Expand Up @@ -473,24 +536,6 @@ async def test_create_log_stream(
payload = await stream.read()
assert payload == b""

async def test_get_node_proxy_client(
self, kube_config: KubeConfig, kube_client: MyKubeClient
) -> None:
node_list = await kube_client.get_node_list()
node_name = node_list["items"][0]["metadata"]["name"]
async with kube_client.get_node_proxy_client(
node_name, KubeConfig.kubelet_node_port
) as client:
assert client.url == URL(
kube_config.endpoint_url
+ f"/api/v1/nodes/{node_name}:{KubeConfig.kubelet_node_port}/proxy"
)

async with client.session.get(URL(f"{client.url}/stats/summary")) as resp:
assert resp.status == 200, await resp.text()
payload = await resp.json()
assert "node" in payload

async def test_get_nodes(self, kube_client: MyKubeClient) -> None:
nodes = await kube_client.get_nodes()
assert nodes
Expand Down Expand Up @@ -534,7 +579,7 @@ async def test_get_pods(
[
["authorized", True],
["badtoken", False],
]
],
)
async def test_get_pod_container_gpu_stats_handles_unauthorized(
self,
Expand Down

0 comments on commit 9e80638

Please sign in to comment.