Skip to content

Commit

Permalink
Use kr8s in common networking (#809)
Browse files Browse the repository at this point in the history
* Release 2023.8.1

* Use kr8s to start port forward in background thread

* Remove unused auth

* Duplicate networking into operator
  • Loading branch information
jacobtomlinson committed Sep 8, 2023
1 parent 32a4f61 commit ef94c94
Show file tree
Hide file tree
Showing 4 changed files with 246 additions and 18 deletions.
6 changes: 1 addition & 5 deletions dask_kubernetes/operator/controller/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
DaskWorkerGroup,
DaskJob,
)
from dask_kubernetes.common.auth import ClusterAuth
from dask_kubernetes.common.networking import get_scheduler_address
from dask_kubernetes.operator.networking import get_scheduler_address
from distributed.core import rpc, clean_exception
from distributed.protocol.pickle import dumps
import dask.config
Expand Down Expand Up @@ -246,9 +245,6 @@ def build_cluster_spec(name, worker_spec, scheduler_spec, annotations, labels):

@kopf.on.startup()
async def startup(settings: kopf.OperatorSettings, **kwargs):
# Authenticate with k8s
await ClusterAuth.load_first()

# Set server and client timeouts to reconnect from time to time.
# In rare occasions the connection might go idle we will no longer receive any events.
# These timeouts should help in those cases.
Expand Down
18 changes: 6 additions & 12 deletions dask_kubernetes/operator/kubecluster/kubecluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,11 @@
format_dashboard_link,
)

from dask_kubernetes.common.auth import ClusterAuth
from dask_kubernetes.common.networking import (
from dask_kubernetes.operator.networking import (
get_scheduler_address,
wait_for_scheduler,
wait_for_scheduler_comm,
)
from dask_kubernetes.common.utils import get_current_namespace
from dask_kubernetes.exceptions import CrashLoopBackOffError, SchedulerStartupError
from dask_kubernetes.operator._objects import (
DaskCluster,
Expand Down Expand Up @@ -86,9 +84,6 @@ class KubeCluster(Cluster):
The command to use when starting the worker.
If command consists of multiple words it should be passed as a list of strings.
Defaults to ``"dask-worker"``.
auth: List[ClusterAuth] (optional)
Configuration methods to attempt in order. Defaults to
``[InCluster(), KubeConfig()]``.
port_forward_cluster_ip: bool (optional)
If the chart uses ClusterIP type services, forward the
ports locally. If you are running it locally it should
Expand Down Expand Up @@ -172,7 +167,6 @@ def __init__(
resources=None,
env=None,
worker_command=None,
auth=ClusterAuth.DEFAULT,
port_forward_cluster_ip=None,
create_mode=None,
shutdown_on_close=None,
Expand All @@ -187,9 +181,8 @@ def __init__(
**kwargs,
):
name = dask.config.get("kubernetes.name", override_with=name)
self.namespace = (
dask.config.get("kubernetes.namespace", override_with=namespace)
or get_current_namespace()
self.namespace = dask.config.get(
"kubernetes.namespace", override_with=namespace
)
self.image = dask.config.get("kubernetes.image", override_with=image)
self.n_workers = dask.config.get(
Expand All @@ -207,7 +200,6 @@ def __init__(
self.worker_command = dask.config.get(
"kubernetes.worker-command", override_with=worker_command
)
self.auth = auth
self.port_forward_cluster_ip = dask.config.get(
"kubernetes.port-forward-cluster-ip", override_with=port_forward_cluster_ip
)
Expand Down Expand Up @@ -295,13 +287,15 @@ def dashboard_link(self):
return format_dashboard_link(host, self.forwarded_dashboard_port)

async def _start(self):
if not self.namespace:
api = await kr8s.asyncio.api()
self.namespace = api.namespace
try:
watch_component_status_task = asyncio.create_task(
self._watch_component_status()
)
if not self.quiet:
show_rich_output_task = asyncio.create_task(self._show_rich_output())
await ClusterAuth.load_first(self.auth)
cluster = await DaskCluster(self.name, namespace=self.namespace)
cluster_exists = await cluster.exists()

Expand Down
238 changes: 238 additions & 0 deletions dask_kubernetes/operator/networking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
import asyncio
from contextlib import suppress
import random
import socket
import time
import threading
from tornado.iostream import StreamClosedError

import kr8s
from kr8s.asyncio.objects import Pod, Service
from distributed.core import rpc

from dask_kubernetes.exceptions import CrashLoopBackOffError


async def get_internal_address_for_scheduler_service(
    service,
    port_forward_cluster_ip=None,
    service_name_resolution_retries=20,
    port_name="tcp-comm",
    local_port=None,
):
    """Return a ``tcp://`` scheduler address reachable via the cluster network.

    The in-cluster DNS name ``<name>.<namespace>`` is tried first (unless
    ``port_forward_cluster_ip`` is truthy). If it cannot be resolved the
    service is port forwarded to ``localhost`` instead.

    Parameters
    ----------
    service
        Service object exposing ``metadata.name``, ``metadata.namespace``
        and ``spec.ports``.
    port_forward_cluster_ip : bool (optional)
        When truthy, skip name resolution and always port forward.
    service_name_resolution_retries : int
        Number of DNS resolution attempts before falling back to a
        port forward.
    port_name : str
        Name of the service port to connect to.
    local_port : int (optional)
        Local port to use for the port forward.

    Returns
    -------
    str
        Address of the form ``tcp://host:port``.
    """
    remote_port = _get_port(service, port_name)
    meta = service.metadata

    if not port_forward_cluster_ip:
        # Inside the cluster the service name resolves through cluster DNS.
        cluster_host = f"{meta.name}.{meta.namespace}"
        try:
            resolved = await _is_service_available(
                host=cluster_host,
                port=remote_port,
                retries=service_name_resolution_retries,
            )
        except socket.gaierror:
            resolved = None
        if resolved:
            return f"tcp://{cluster_host}:{remote_port}"

    # The name did not resolve (or forwarding was requested), so we are
    # outside the cluster and must reach the service through a port forward.
    forward_host = "localhost"
    forwarded_port = await port_forward_service(
        meta.name, meta.namespace, remote_port, local_port
    )
    return f"tcp://{forward_host}:{forwarded_port}"


async def get_external_address_for_scheduler_service(
    service,
    port_forward_cluster_ip=None,
    service_name_resolution_retries=20,
    port_name="tcp-comm",
    local_port=None,
):
    """Return a ``tcp://`` scheduler address based on the service type.

    * ``LoadBalancer``: hostname (or IP) of the first ingress entry and the
      service port.
    * ``NodePort``: first address of the first node returned by the API and
      the node port.
    * ``ClusterIP``: delegated to
      :func:`get_internal_address_for_scheduler_service` (DNS resolution with
      a port-forward fallback).

    Parameters
    ----------
    service
        Service object exposing ``spec.type``, ``spec.ports``, ``metadata``
        and (for load balancers) ``status``.
    port_forward_cluster_ip : bool (optional)
        For ``ClusterIP`` services, skip name resolution and port forward.
    service_name_resolution_retries : int
        DNS resolution attempts for ``ClusterIP`` services.
    port_name : str
        Name of the service port to connect to.
    local_port : int (optional)
        Local port to use if a port forward is required.

    Returns
    -------
    str
        Address of the form ``tcp://host:port``.

    Raises
    ------
    ValueError
        If the service type is not one of the three supported types.
    """
    service_type = service.spec.type

    if service_type == "ClusterIP":
        # Identical logic to the internal resolver, so share it instead of
        # keeping two copies in sync.
        return await get_internal_address_for_scheduler_service(
            service,
            port_forward_cluster_ip=port_forward_cluster_ip,
            service_name_resolution_retries=service_name_resolution_retries,
            port_name=port_name,
            local_port=local_port,
        )

    if service_type == "LoadBalancer":
        port = _get_port(service, port_name)
        status = service.status
        # Raw Kubernetes resources use camelCase keys; keep the snake_case
        # attribute as a fallback for client models that translate names.
        load_balancer = getattr(status, "loadBalancer", None) or status.load_balancer
        lb = load_balancer.ingress[0]
        # An ingress entry may carry only one of hostname / ip.
        host = getattr(lb, "hostname", None) or lb.ip
    elif service_type == "NodePort":
        port = _get_port(service, port_name, is_node_port=True)
        nodes = await kr8s.asyncio.get("nodes")
        host = nodes[0].status.addresses[0].address
    else:
        # Previously this fell through to the return below and failed with
        # an UnboundLocalError on ``host``.
        raise ValueError(
            f"Unsupported service type {service_type!r}, expected one of "
            "'LoadBalancer', 'NodePort' or 'ClusterIP'."
        )

    return f"tcp://{host}:{port}"


def _get_port(service, port_name, is_node_port=False):
    """Return the port number of the service port matching ``port_name``.

    A port matches when its name equals ``port_name`` or the name of the
    service itself. NodePort services are a special case where the node
    port has to be used instead of the service port.

    Parameters
    ----------
    service
        Service object exposing ``metadata.name`` and ``spec.ports``.
    port_name : str
        Name of the port to look up.
    is_node_port : bool
        Return the node port rather than the service port.

    Returns
    -------
    int
        The selected port number.

    Raises
    ------
    ValueError
        If zero or more than one port matches.
    """
    service_name = service.metadata.name
    matches = [
        candidate
        for candidate in service.spec.ports
        if candidate.name == service_name or candidate.name == port_name
    ]
    if len(matches) != 1:
        raise ValueError(
            f"Expected exactly one port named {port_name!r} or {service_name!r} "
            f"on service {service_name!r}, found {len(matches)}."
        )
    (match,) = matches

    if not is_node_port:
        return match.port

    # Raw Kubernetes resources spell this ``nodePort``; fall back to the
    # snake_case attribute for client models that translate names.
    node_port = getattr(match, "nodePort", None)
    if node_port is None:
        node_port = match.node_port
    return node_port


async def _is_service_available(host, port, retries=20):
    """Resolve ``host``/``port``, retrying while name resolution fails.

    Parameters
    ----------
    host : str
        Hostname to resolve.
    port : int
        Port passed through to ``getaddrinfo``.
    retries : int
        Maximum number of resolution attempts, 0.5 seconds apart.

    Returns
    -------
    list or None
        The ``getaddrinfo`` result of the first successful attempt, or
        ``None`` when ``retries`` is zero or negative (no attempt is made).

    Raises
    ------
    socket.gaierror
        If the final attempt still fails to resolve.
    """
    # We are always called from a coroutine, so ask for the running loop
    # explicitly rather than relying on get_event_loop()'s lookup policy.
    loop = asyncio.get_running_loop()
    for attempt in range(retries):
        try:
            return await loop.getaddrinfo(host, port)
        except socket.gaierror:
            if attempt >= retries - 1:
                # Out of attempts: surface the original resolution error.
                raise
            await asyncio.sleep(0.5)
    return None


def _port_in_use(port):
    """Report whether a local TCP port cannot be bound.

    Parameters
    ----------
    port : int or None
        Port number to probe. ``None`` is reported as in use.

    Returns
    -------
    bool
        ``True`` if ``port`` is ``None`` or binding to it on all interfaces
        fails with ``OSError``, ``False`` if the bind succeeds.
    """
    if port is None:
        return True
    # The context manager closes the probe socket on both outcomes; the
    # failed-bind path used to leave it open.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        try:
            probe.bind(("", port))
        except OSError:
            return True
    return False


def _random_free_port(low, high, retries=20):
    """Pick a random local TCP port in ``[low, high]`` that can be bound.

    Parameters
    ----------
    low, high : int
        Inclusive bounds of the port range to sample from.
    retries : int
        Maximum number of random ports to try.

    Returns
    -------
    int
        A port that was successfully bound (and released again) on all
        interfaces.

    Raises
    ------
    ConnectionError
        If no attempt managed to bind a port.
    """
    for _ in range(retries):
        guess = random.randint(low, high)
        # A fresh socket per attempt, closed by the context manager, so no
        # descriptor is leaked when every attempt fails.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            try:
                probe.bind(("", guess))
            except OSError:
                continue
        return guess
    raise ConnectionError("Not able to find a free port.")


async def port_forward_service(service_name, namespace, remote_port, local_port=None):
    """Start a background port forward to a service and wait until it is open.

    Parameters
    ----------
    service_name : str
        Name of the service to forward to.
    namespace : str
        Namespace of the service.
    remote_port : int
        Port on the service.
    local_port : int (optional)
        Local port to listen on. A random free port is chosen when omitted.

    Returns
    -------
    int
        The local port the forward is listening on.

    Raises
    ------
    ConnectionError
        If the requested local port is already in use, or the forwarded
        port never accepts a connection.
    """
    if local_port:
        if _port_in_use(local_port):
            raise ConnectionError("Specified Port already in use.")
    else:
        local_port = _random_free_port(49152, 65535)  # IANA suggested range

    thread_label = (
        "DaskKubernetesPortForward "
        f"({namespace}/{service_name} {local_port}->{remote_port})"
    )
    # The forward runs in a daemon thread so it never blocks interpreter exit.
    forwarder = threading.Thread(
        name=thread_label,
        target=run_port_forward,
        args=(service_name, namespace, remote_port, local_port),
        daemon=True,
    )
    forwarder.start()

    forward_ready = await is_comm_open("localhost", local_port, retries=2000)
    if not forward_ready:
        raise ConnectionError("port forward failed")
    return local_port


def run_port_forward(service_name, namespace, remote_port, local_port):
    """Hold a port forward to a service open, blocking the calling thread.

    Looks up the service, enters its port forward context and then idles
    forever inside a fresh event loop started with ``asyncio.run``.
    ``port_forward_service`` uses this function as a thread target.

    Parameters
    ----------
    service_name : str
        Name of the service to forward to.
    namespace : str
        Namespace of the service.
    remote_port : int
        Port on the service.
    local_port : int
        Local port to listen on.
    """

    async def _hold_forward_open():
        target = await Service.get(service_name, namespace=namespace)
        async with target.portforward(remote_port, local_port):
            # Keep the context (and therefore the forward) alive indefinitely.
            while True:
                await asyncio.sleep(0.1)

    forward_coro = _hold_forward_open()
    asyncio.run(forward_coro)


async def is_comm_open(ip, port, retries=200):
    """Poll until a TCP connection to ``ip:port`` succeeds.

    Parameters
    ----------
    ip : str
        Host to connect to.
    port : int
        Port to connect to.
    retries : int
        Maximum number of connection attempts, 0.1 seconds apart.

    Returns
    -------
    bool
        ``True`` as soon as a connection succeeds, ``False`` once all
        attempts are used up.
    """
    while retries > 0:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            # NOTE: connect_ex itself is still a blocking call.
            result = sock.connect_ex((ip, port))
        if result == 0:
            return True
        # Yield to the event loop between attempts. A blocking time.sleep
        # here would freeze every other task for the whole polling period.
        await asyncio.sleep(0.1)
        retries -= 1
    return False


async def port_forward_dashboard(service_name, namespace):
    """Port forward remote port 8787 of a service and return the local port.

    Parameters
    ----------
    service_name : str
        Name of the service to forward to.
    namespace : str
        Namespace of the service.

    Returns
    -------
    int
        The local port returned by ``port_forward_service``.
    """
    return await port_forward_service(service_name, namespace, 8787)


async def get_scheduler_address(
    service_name,
    namespace,
    port_name="tcp-comm",
    port_forward_cluster_ip=None,
    local_port=None,
    allow_external=True,
):
    """Look up a scheduler service and return an address to reach it.

    Parameters
    ----------
    service_name : str
        Name of the scheduler service.
    namespace : str
        Namespace of the service.
    port_name : str
        Name of the service port to connect to.
    port_forward_cluster_ip : bool (optional)
        Passed through to the address resolver.
    local_port : int (optional)
        Local port to use if a port forward is required.
    allow_external : bool
        Use the external address resolver when true, otherwise the
        internal one.

    Returns
    -------
    str
        Address of the form ``tcp://host:port``.
    """
    service = await Service.get(service_name, namespace=namespace)

    # Both resolvers share the same signature, so pick one and call it once.
    if allow_external:
        resolve_address = get_external_address_for_scheduler_service
    else:
        resolve_address = get_internal_address_for_scheduler_service

    return await resolve_address(
        service,
        port_forward_cluster_ip=port_forward_cluster_ip,
        port_name=port_name,
        local_port=local_port,
    )


async def wait_for_scheduler(cluster_name, namespace, timeout=None):
    """Poll until the scheduler pod of a cluster reports ready.

    The pod is looked up by its ``dask.org/component=scheduler`` and
    ``dask.org/cluster-name`` labels every 0.25 seconds.

    Parameters
    ----------
    cluster_name : str
        Name of the Dask cluster whose scheduler to wait for.
    namespace : str
        Namespace to search for the scheduler pod.
    timeout : float (optional)
        Seconds a crash-looping scheduler is tolerated before giving up.
        When falsy the function waits indefinitely.

    Raises
    ------
    CrashLoopBackOffError
        If a scheduler container is in ``CrashLoopBackOff`` and more than
        ``timeout`` seconds have passed since the pod was first seen
        running (or first seen crash-looping, whichever came first).
    """
    pod_start_time = None
    while True:
        try:
            pod = await Pod.get(
                label_selector=f"dask.org/component=scheduler,dask.org/cluster-name={cluster_name}",
                namespace=namespace,
            )
        except kr8s.NotFoundError:
            # The pod has not been created yet, keep polling.
            await asyncio.sleep(0.25)
            continue
        if pod.status.phase == "Running":
            if not pod_start_time:
                pod_start_time = time.time()
            if await pod.ready():
                return
        if "containerStatuses" in pod.status:
            for container in pod.status.containerStatuses:
                if (
                    "waiting" in container.state
                    and container.state.waiting.reason == "CrashLoopBackOff"
                    and timeout
                ):
                    if pod_start_time is None:
                        # Crash loop seen before the pod was ever observed
                        # Running: start the clock now rather than failing
                        # with a TypeError on ``None + timeout``.
                        pod_start_time = time.time()
                    if pod_start_time + timeout < time.time():
                        raise CrashLoopBackOffError(
                            f"Scheduler in CrashLoopBackOff for more than {timeout} seconds."
                        )
        await asyncio.sleep(0.25)


async def wait_for_scheduler_comm(address):
    """Block until the scheduler at ``address`` answers an RPC call.

    Repeatedly opens an ``rpc`` connection and requests ``versions()``.
    ``StreamClosedError`` and ``OSError`` are treated as "not up yet" and
    retried after 0.1 seconds; any other exception propagates.

    Parameters
    ----------
    address : str
        Scheduler address to connect to.
    """
    connected = False
    while not connected:
        try:
            async with rpc(address) as scheduler_comm:
                await scheduler_comm.versions()
        except (StreamClosedError, OSError):
            await asyncio.sleep(0.1)
        else:
            connected = True
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ kubernetes-asyncio>=12.0.1
kopf>=1.35.3
pykube-ng>=22.9.0
rich>=12.5.1
kr8s==0.8.15
kr8s==0.8.16

0 comments on commit ef94c94

Please sign in to comment.