In [1]:
# TODO (andreyvelich): Change to release version when SDK with the new APIs is published.
!pip install git+https://github.com/kubeflow/training-operator.git#subdirectory=sdk/python

Collecting git+https://github.com/kubeflow/training-operator.git#subdirectory=sdk/python
  Cloning https://github.com/kubeflow/training-operator.git to /tmp/pip-req-build-run_74fp
  Running command git clone --filter=blob:none --quiet https://github.com/kubeflow/training-operator.git /tmp/pip-req-build-run_74fp
  Resolved https://github.com/kubeflow/training-operator.git to commit 9e46f9d422e71f258679c5edd306c7eddf9004f1
  Preparing metadata (setup.py) ... [?25ldone
Building wheels for collected packages: kubeflow-training
  Building wheel for kubeflow-training (setup.py) ... [?25ldone
[?25h  Created wheel for kubeflow-training: filename=kubeflow_training-1.8.1-py3-none-any.whl size=140130 sha256=aca5837d34b51a293086fa39686d2c7a2708837d8aeaf267ef8cfb272d9d8fa3
  Stored in directory: /tmp/pip-ephem-wheel-cache-d95lcrol/wheels/4e/97/bb/7c46e489ad7772669c94e462b1f545c475d32d70259ba08209
Successfully built kubeflow-training
Installing collected packages: kubeflow-training
Successfully i

In [2]:
import gradio as gr
import time
import textwrap
import os
import kubernetes

from kubernetes.client import V1ResourceRequirements, V1VolumeMount, V1Volume, V1PersistentVolumeClaimVolumeSource, V1Container, V1PodSpec, V1PodTemplateSpec, V1ObjectMeta
from kubeflow.training import TrainingClient
from kubeflow.training.models import KubeflowOrgV1ReplicaSpec, KubeflowOrgV1PyTorchJob, KubeflowOrgV1PyTorchJobSpec, KubeflowOrgV1RunPolicy
from kubeflow.training.constants import constants

In [3]:
class PyTorchJobManager:
    """Create, monitor and delete a single Kubeflow PyTorchJob from a notebook.

    The job is made of one "Master" replica and ``WORKER_REPLICAS`` "Worker"
    replicas that all run the same container and mount the same
    PersistentVolumeClaim at ``MOUNT_PATH``.

    Attributes:
        namespace: Kubernetes namespace the job is created in.
        name: Name of the PyTorchJob (also used to locate the master pod logs).
        container_name: Name given to the training container.
        training_client: ``TrainingClient`` used for every API call.
        job_running: True after ``create_pytorch_job`` succeeded and until
            ``delete_job`` is called on this instance.
        image, command, working_dir: Container settings passed to ``V1Container``.
        last_logs: Last log snapshot printed by ``wait_logs(ui=False)``.
    """

    # Seconds slept between two polls of the job state / logs.
    POLL_INTERVAL_SECONDS = 1
    # Number of "Worker" replicas (the "Master" replica count is always 1).
    WORKER_REPLICAS = 4
    # Name of the pod volume and the path it is mounted at inside the container.
    VOLUME_NAME = "model-volume"
    MOUNT_PATH = "/home/jovyan/"
    # Used for both requests and limits of every replica.
    RESOURCES = {"cpu": "4", "memory": "8Gi", "nvidia.com/gpu": "1"}

    def __init__(self, namespace=None, name="pytorch-dist-mnist-gloo", image=None, command=None, working_dir=None):
        """Store the job settings and build the training client.

        Args:
            namespace: Target namespace. When None, the client is built without
                an explicit namespace so the SDK applies its own default, and
                that default is reused here.
            name: PyTorchJob name.
            image: Container image for every replica.
            command: Container command (list of strings).
            working_dir: Container working directory.
        """
        if namespace is None:
            # Forwarding namespace=None would override the SDK's default.
            self.training_client = TrainingClient()
            # NOTE(review): assumes TrainingClient exposes `.namespace` — confirm for the installed SDK.
            namespace = getattr(self.training_client, "namespace", None)
        else:
            self.training_client = TrainingClient(namespace=namespace)
        self.namespace = namespace
        self.name = name
        self.container_name = "pytorch"
        self.job_running = False
        self.image = image
        self.command = command
        self.working_dir = working_dir
        self.last_logs = ""

    # Create the PyTorchJob
    def create_pytorch_job(self):
        """Build the PyTorchJob spec and submit it through the training client.

        Sets ``job_running`` to True once ``create_job`` returns.
        """
        volume = V1Volume(
            name=self.VOLUME_NAME,
            persistent_volume_claim=V1PersistentVolumeClaimVolumeSource(
                claim_name=self._get_volume_claim_name()
            ),
        )
        container = V1Container(
            name=self.container_name,
            image=self.image,
            command=self.command,
            working_dir=self.working_dir,
            resources=V1ResourceRequirements(
                # Separate copies so requests and limits never alias the class constant.
                requests=dict(self.RESOURCES),
                limits=dict(self.RESOURCES),
            ),
            volume_mounts=[V1VolumeMount(name=self.VOLUME_NAME, mount_path=self.MOUNT_PATH)],
        )

        pytorchjob = KubeflowOrgV1PyTorchJob(
            api_version=constants.API_VERSION,
            kind=constants.PYTORCHJOB_KIND,
            metadata=V1ObjectMeta(name=self.name, namespace=self.namespace),
            spec=KubeflowOrgV1PyTorchJobSpec(
                run_policy=KubeflowOrgV1RunPolicy(clean_pod_policy="None"),
                pytorch_replica_specs={
                    "Master": self._build_replica_spec(1, container, volume),
                    "Worker": self._build_replica_spec(self.WORKER_REPLICAS, container, volume),
                },
            ),
        )

        self.training_client.create_job(pytorchjob)
        self.job_running = True

    def _build_replica_spec(self, replicas, container, volume):
        """Return a replica spec running `container` with `volume` attached, `replicas` times."""
        return KubeflowOrgV1ReplicaSpec(
            replicas=replicas,
            restart_policy="OnFailure",
            template=V1PodTemplateSpec(
                metadata=V1ObjectMeta(
                    name=self.name,
                    namespace=self.namespace,
                    annotations={
                        "sidecar.istio.io/inject": "false"
                    },
                ),
                spec=V1PodSpec(containers=[container], volumes=[volume]),
            ),
        )

    @staticmethod
    def _notebook_name_from_hostname(hostname):
        """Strip the trailing ``-<suffix>`` from a pod hostname.

        Only the last hyphen-separated segment is removed, so notebook names
        that themselves contain hyphens are kept intact.
        NOTE(review): assumes the notebook pod is named ``<notebook>-<ordinal>`` — confirm.
        """
        return hostname.rsplit('-', 1)[0]

    @staticmethod
    def _select_claim_name(volumes):
        """Pick the PVC claim name out of a notebook's pod volumes.

        The second volume is preferred (the historical behaviour of this
        class); when it is missing or not PVC-backed, the first PVC-backed
        volume is used instead.

        Raises:
            KeyError: when no volume references a PersistentVolumeClaim.
        """
        def claim_of(volume):
            return (volume.get('persistentVolumeClaim') or {}).get('claimName')

        preferred = claim_of(volumes[1]) if len(volumes) > 1 else None
        if preferred:
            return preferred
        for volume in volumes:
            claim = claim_of(volume)
            if claim:
                return claim
        raise KeyError("No persistentVolumeClaim volume found in the notebook spec")

    # Get the volume claim name
    def _get_volume_claim_name(self):
        """Look up the PVC used by the notebook this code is running in.

        The notebook name is derived from the HOSTNAME environment variable
        and the Notebook custom object is read with in-cluster credentials.
        """
        notebook_name = self._notebook_name_from_hostname(os.environ.get('HOSTNAME', 'notebook'))
        kubernetes.config.load_incluster_config()
        crd_api = kubernetes.client.CustomObjectsApi()
        crd_group = 'kubeflow.org'
        crd_version = 'v1alpha1'
        crd_plural = 'notebooks'
        notebook = crd_api.get_namespaced_custom_object(crd_group, crd_version, self.namespace, crd_plural, notebook_name)
        return self._select_claim_name(notebook['spec']['template']['spec']['volumes'])

    # Get the job
    def get_job(self):
        """Return the PyTorchJob object fetched by the training client."""
        return self.training_client.get_job(self.name, job_kind=constants.PYTORCHJOB_KIND)

    # Wait for the job to finish
    def wait_for_job(self, wait_timeout=900):
        """Block on the client's ``wait_for_job_conditions`` for up to `wait_timeout`."""
        return self.training_client.wait_for_job_conditions(
            name=self.name,
            job_kind=constants.PYTORCHJOB_KIND,
            wait_timeout=wait_timeout,
        )

    # Check whether the job succeeded
    def is_job_succeeded(self):
        """Return the client's ``is_job_succeeded`` result for this job."""
        return self.training_client.is_job_succeeded(name=self.name, job_kind=constants.PYTORCHJOB_KIND)

    # Check whether the job failed
    def is_job_failed(self):
        """Return the client's ``is_job_failed`` result for this job."""
        return self.training_client.is_job_failed(name=self.name, job_kind=constants.PYTORCHJOB_KIND)

    # Get the job logs
    def get_job_logs(self):
        """Return whatever the client's ``get_job_logs`` returns for this job."""
        return self.training_client.get_job_logs(name=self.name, job_kind=constants.PYTORCHJOB_KIND)

    # Delete the job
    def delete_job(self):
        """Delete the PyTorchJob and mark it as no longer running."""
        # job_kind is passed explicitly, like every other client call in this class.
        result = self.training_client.delete_job(self.name, job_kind=constants.PYTORCHJOB_KIND)
        self.job_running = False
        return result

    def fetch_logs(self):
        """Return the master pod's logs, None if absent, or an error string.

        Best effort on purpose: any exception is turned into a message so the
        UI callbacks can show it instead of crashing.
        """
        try:
            logs, _ = self.get_job_logs()
            return logs.get(f'{self.name}-master-0', None)
        except Exception as e:
            return f"Error while fetching logs: {e}"

    def display_logs(self):
        """Yield log snapshots until the job reaches a final state.

        Yields "Completed" on success. On failure, yields the last logs
        followed by "Failed" and stops, instead of polling forever.
        """
        while True:
            if self.is_job_succeeded():
                yield "Completed"
                return
            log_content = self.fetch_logs()
            if self.is_job_failed():
                yield f"{log_content}\nFailed" if log_content else "Failed"
                return
            yield log_content if log_content is not None else "No logs available."
            time.sleep(self.POLL_INTERVAL_SECONDS)

    # Show the Gradio interface
    def show_gradio_interface(self, share=True):
        """Build and launch the Gradio UI with start / check / clear buttons.

        Args:
            share: Forwarded to ``demo.launch``. Defaults to True, as before.
                NOTE(review): a shared link lets anyone holding the URL
                create or delete the job — confirm this is intended.
        """
        def run_and_display():
            # Create the job only when nothing exists yet; in every other case
            # attach to the log stream (it reports "Completed" right away when
            # the job already succeeded).
            if self.fetch_logs() is None and not self.job_running:
                self.create_pytorch_job()
            yield from self.display_logs()

        def clear_and_delete():
            log_content = self.fetch_logs()
            if log_content is None:
                return "需先點擊開始訓練"
            if log_content != 'Terminating':
                self.delete_job()
            return "Terminating"

        def check_status():
            log_content = self.fetch_logs()
            if log_content is None:
                return "需先點擊開始訓練"
            if self.is_job_succeeded():
                return "Completed"
            return log_content if log_content else "No logs available."

        # Custom JavaScript intended to auto-scroll the log textbox.
        # NOTE(review): a <script> passed through gr.HTML may never execute — verify in the browser.
        custom_js = """
        <script>
        function scrollToBottom(id) {
            var textbox = document.getElementById(id);
            if (textbox) {
                textbox.scrollTop = textbox.scrollHeight;
            }
        }

        function addAutoScroll(textbox_id) {
            var textbox = document.getElementById(textbox_id);
            if (textbox) {
                var isAutoScroll = true;
                textbox.addEventListener('scroll', function() {
                    if (textbox.scrollTop + textbox.clientHeight < textbox.scrollHeight) {
                        isAutoScroll = false;
                    } else {
                        isAutoScroll = true;
                    }
                });

                setInterval(function() {
                    if (isAutoScroll) {
                        scrollToBottom(textbox_id);
                    }
                }, 100);
            }
        }

        // 在文本框加載完成後初始化自動滾動
        document.addEventListener("DOMContentLoaded", function() {
            addAutoScroll('training_logs_textbox');
        });
        </script>
        """

        # CSS to fix the height of the Textbox
        custom_css = """
        <style>
        #training_logs_textbox {
            height: 300px;
            overflow: auto;
        }
        </style>
        """

        with gr.Blocks() as demo:
            gr.Markdown("點擊按鈕開始訓練模型並顯示日誌")

            out = gr.Textbox(
                label="Training Logs",
                elem_id="training_logs_textbox",
                lines=13,  # fixed number of visible lines
            )
            with gr.Row():
                btn_start = gr.Button("開始訓練")
                btn_look = gr.Button("察看進度")
                btn_clear = gr.Button("清除日誌")

            btn_start.click(fn=run_and_display, inputs=[], outputs=out)
            btn_look.click(fn=check_status, inputs=[], outputs=out)
            btn_clear.click(fn=clear_and_delete, inputs=[], outputs=out)  # deletes the job

            gr.HTML(custom_css)  # inject the custom CSS
            gr.HTML(custom_js)   # inject the custom JavaScript

        demo.launch(share=share)

    def _print_new_logs(self):
        """Print the part of the master logs not printed yet and remember the snapshot."""
        # "No logs yet" (None) is treated as empty text so last_logs always stays a string.
        log_content = self.fetch_logs() or ""
        previous = self.last_logs or ""
        if log_content == previous:
            return
        if previous and log_content.startswith(previous):
            new_logs = log_content[len(previous):]
        else:
            new_logs = log_content
        self.last_logs = log_content
        if not new_logs:
            return
        # Wrap line by line: textwrap.fill on the whole text would merge all log lines.
        wrapped_logs = "\n".join(textwrap.fill(line, width=80) for line in new_logs.splitlines())
        print(f"\n--- New Logs at {time.strftime('%Y-%m-%d %H:%M:%S')} ---\n")
        print(wrapped_logs)

    # Wait for the job and show its logs
    def wait_logs(self, ui=True):
        """Follow the job either in the Gradio UI or on stdout.

        Args:
            ui: When True, launch the Gradio interface. When False, poll the
                job and print new log text until it succeeds ("Completed") or
                fails ("Failed").

        Raises:
            Exception: anything raised while polling is printed and re-raised.
        """
        if ui:
            self.show_gradio_interface()
            return
        try:
            while True:
                # Read the state first so the last lines are still printed after the job ends.
                succeeded = self.is_job_succeeded()
                failed = (not succeeded) and self.is_job_failed()
                self._print_new_logs()
                if succeeded:
                    print("Completed")
                    return
                if failed:
                    print("Failed")
                    return
                time.sleep(self.POLL_INTERVAL_SECONDS)
        except Exception as e:
            print(f"發生錯誤: {e}")
            raise

# Usage example
# Constructing the manager only builds the TrainingClient; no job is created yet.
# NOTE(review): the namespace and image look specific to one deployment — adjust before reuse.
manager = PyTorchJobManager(
    namespace="dm1261010",
    image="cguaicadmin/newlab-newpytorch:V1.0.23",
    command=["/home/jovyan/test.sh"],
    working_dir="/home/jovyan"
)

In [4]:
# ui defaults to True, so this call builds and launches the Gradio dashboard;
# use manager.wait_logs(ui=False) to print the master pod's logs in this cell instead.
manager.wait_logs()

Running on local URL:  http://127.0.0.1:7860
Running on public URL: https://2c0748e3646e52c5ad.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


In [5]:
# logs, _ = manager.training_client.get_job_logs(manager.name)
# print(logs['pytorch-dist-mnist-gloo-master-0'])

In [6]:
# manager.delete_job()