Commit

use queue
Windfarer committed Jun 3, 2021
1 parent 47cbaa8 commit dcae3b1
Showing 1 changed file with 36 additions and 8 deletions.
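Before this change, get_logs followed multiple pods by calling next() on each pod's watch stream in turn; next() blocks until that pod emits a line, so a single quiet pod could hold back every other pod's output. This commit adds a reader thread per stream that feeds a bounded queue, and the consuming loop polls each queue with a timeout instead. A minimal, self-contained sketch of the same producer-consumer pattern, using plain Python iterators as stand-ins for the Kubernetes watch streams (names like "pod-a" are illustrative only):

import queue
import threading

def reader(q, stream):
  # Producer: drain one stream into its queue; None marks end-of-stream.
  for line in stream:
    q.put(line)
  q.put(None)

# Stand-ins for per-pod log streams.
streams = {"pod-a": iter(["a1", "a2", "a3"]), "pod-b": iter(["b1"])}

queues = {}
for name, stream in streams.items():
  q = queue.Queue(maxsize=100)  # bounded: a fast producer blocks rather than buffering forever
  # Daemonized here so the demo cannot hang on exit; the SDK helper below
  # starts plain (non-daemon) threads.
  threading.Thread(target=reader, args=(q, stream), daemon=True).start()
  queues[name] = q

finished = set()
while len(finished) < len(queues):
  for name, q in queues.items():
    if name in finished:
      continue
    try:
      line = q.get(timeout=1)  # wait briefly, then move on to the next pod
    except queue.Empty:
      continue
    if line is None:
      finished.add(name)
    else:
      print("[Pod %s]: %s" % (name, line))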
44 changes: 36 additions & 8 deletions sdk/python/kubeflow/tfjob/api/tf_job_client.py
@@ -14,6 +14,8 @@
import multiprocessing
import time
import logging
+import threading
+import queue

from kubernetes import client, config
from kubernetes import watch as k8s_watch
@@ -26,8 +28,30 @@
logging.basicConfig(format='%(message)s')
logging.getLogger().setLevel(logging.INFO)

-class TFJobClient(object):

+def wrap_log_stream(q, stream):
+  while True:
+    try:
+      logline = next(stream)
+      q.put(logline)
+    except StopIteration:
+      q.put(None)
+      return
+    except Exception as e:
+      raise RuntimeError(
+        "Exception when calling CoreV1Api->read_namespaced_pod_log: %s\n" % e)
+
+
+def get_log_queue_pool(streams):
+  pool = []
+  for stream in streams:
+    q = queue.Queue(maxsize=100)
+    pool.append(q)
+    threading.Thread(target=wrap_log_stream, args=(q, stream)).start()
+  return pool


+class TFJobClient(object):
def __init__(self, config_file=None, context=None, # pylint: disable=too-many-arguments
client_configuration=None, persist_config=True):
"""
@@ -353,7 +377,6 @@ def get_pod_names(self, name, namespace=None, master=False, #pylint: disable=inc
    else:
      return set(pod_names)

-
def get_logs(self, name, namespace=None, master=True,
replica_type=None, replica_index=None,
follow=False):
@@ -378,28 +401,33 @@ def get_logs(self, name, namespace=None, master=True,
master=master,
replica_type=replica_type,
replica_index=replica_index))

    if pod_names:
      if follow:
        log_streams = []
        for pod in pod_names:
          log_streams.append(k8s_watch.Watch().stream(self.core_api.read_namespaced_pod_log,
                                                       name=pod, namespace=namespace))
        finished = [False for _ in log_streams]
-        # iterate over every watching pods' log
+
+        # create a thread and a queue per stream, for non-blocking iteration
+        log_queue_pool = get_log_queue_pool(log_streams)
+
+        # iterate over every watched pod's log queue
        while True:
-          for index, stream in enumerate(log_streams):
+          for index, log_queue in enumerate(log_queue_pool):
            if all(finished):
              return
            if finished[index]:
              continue
            # group every 50 log lines from the same pod
            for _ in range(50):
              try:
-                logline = next(stream)
+                logline = log_queue.get(timeout=1)
+                if logline is None:
+                  finished[index] = True
+                  break
                logging.info("[Pod %s]: %s", pod_names[index], logline)
-              except StopIteration:
-                finished[index] = True
+              except queue.Empty:
                break
    else:
      for pod in pod_names:
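For context, a minimal usage sketch of the API this change affects (the job name "mnist" and namespace "kubeflow" are hypothetical, and the import assumes TFJobClient is re-exported at the package level as in this SDK):

from kubeflow.tfjob import TFJobClient

tfjob_client = TFJobClient()

# Tail the master replica's log until its stream ends. With follow=True,
# each matching pod's watch stream is drained by its own thread into a
# queue, and lines are logged in round-robin batches of up to 50 per pod.
tfjob_client.get_logs("mnist", namespace="kubeflow", master=True, follow=True)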

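Two design points in the new helpers are worth noting. The queues are bounded at 100 entries, so a pod that logs faster than the consumer drains is paused inside its reader thread instead of buffering without limit; and the consumer's one-second get() timeout keeps the round-robin over pods moving even while some pods are silent, which is exactly the stall the old blocking next(stream) loop suffered from. The reader threads are not daemonized, so they run until their watch stream is exhausted.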