# Train Model Job

In [17]:
from kfp.components import (
    create_component_from_func,
    InputPath,
    OutputPath
)
from typing import (
    Dict,
    NamedTuple
)

# Container image of the pipeline step that executes train_model_job (passed to
# create_component_from_func at the bottom of this file). The image must provide
# the `kubernetes` and `yaml` packages imported inside train_model_job.
# NOTE: train_model_job's `base_image` default repeats this value as a literal;
# keep both in sync.
BASE_IMAGE = "quay.io/ibm/kubeflow-notebook-image-ppc64le:latest"


def train_model_job(
        dataset_directory: InputPath(str),
        train_specification: str,
        train_parameters: Dict[str, str],
        model_dir: OutputPath(str),
        train_mount: str = "/train",
        model_name: str = "my-model",
        base_image: str = "quay.io/ibm/kubeflow-notebook-image-ppc64le:latest",
        namespace: str = "",
        node_selector: str = "",
        remote_host: str = "",
        pvc_name: str = "",
        pvc_size: str = "10Gi",
        cpus: str = "8",
        gpus: int = 0,
        memory: str = "32Gi",
) -> NamedTuple(
        'TrainModelJobOutputs', [
            ('logs', str),
        ]):
    '''
    Trains a model in a separate Kubernetes Job. Once trained, the model is persisted to model_dir.

    The training command is taken from the container command and args of the
    given component specification. Its argument placeholders are resolved
    against train_parameters. inputPath/outputPath placeholders are redirected
    below train_mount, where the Job mounts a persistent volume claim. Inputs
    are copied there before the Job starts; outputs are copied back after the
    Job succeeded. Afterwards, the Job (including its pods) and any PVC created
    by this function are deleted.

            Parameters:
                    dataset_directory: Path to the directory with training data.
                    train_specification: Training component specification (YAML) as generated from a Python function using kfp.components.func_to_component_text.
                    train_parameters: Dictionary mapping formal to actual parameters for the training specification. Actual values of inputPath/outputPath parameters must be "dataset_directory" or "model_dir".
                    model_dir: Target path where the model will be stored.
                    train_mount: Optional mounting point for training data of an existing PVC. Data is copied below this path inside the current container, so the same PVC is presumably expected to be mounted here as well. Example: "/train".
                    model_name: Optional name of the model. Must be unique for the targeted namespace and conform to Kubernetes naming conventions. Example: my-model.
                    base_image: Optional base image for model training. Example: quay.io/ibm/kubeflow-notebook-image-ppc64le:latest.
                    namespace: Optional namespace where the Job and associated volumes will be created. By default, the same namespace as where the pipeline is executed is chosen. Example: "user-example-com".
                    node_selector: Optional node selector for worker nodes, given as a YAML mapping. Example: nvidia.com/gpu.product: "Tesla-V100-SXM2-32GB".
                    remote_host: Currently unused.
                    pvc_name: Optional name of an existing persistent volume claim (PVC). If given, this PVC is mounted into the training job. Otherwise, a new PVC is created for the Job and deleted afterwards. Example: "music-genre-classification-j4ssf-training-pvc".
                    pvc_size: Optional size of the PVC that is created when no pvc_name is given. Example: 10Gi.
                    cpus: Optional CPU limit for the job. Example: "1000m".
                    gpus: Optional number of GPUs for the job. Example: 2.
                    memory: Optional memory limit for the job. Example: "1Gi".
            Returns:
                    logs: Log output of the training pod. Example: "...Job finished successfully".
    '''
    # Imports and helpers stay inside this function: create_component_from_func
    # ships only this function's source code into the component.
    from collections import namedtuple
    from datetime import datetime
    import json
    from kubernetes import (
        client,
        config,
        utils,
        watch
    )
    from kubernetes.client.rest import ApiException
    import logging
    import os
    import shutil
    import sys
    import yaml

    logging.basicConfig(
        stream=sys.stdout,
        level=logging.INFO,
        format='%(levelname)s %(asctime)s: %(message)s'
    )
    logger = logging.getLogger()

    SA_NAMESPACE = "/var/run/secrets/kubernetes.io/serviceaccount/namespace"

    def fail(message):
        """Print message and raise it as a generic Exception."""
        print(message)
        raise Exception(message)

    def clone_path(source, target):
        """Copy the file or directory source to target, replacing an existing target."""
        if os.path.realpath(source) == os.path.realpath(target):
            logger.info(f"Source and target path {source} are identical; nothing to clone.")
            return
        logger.info(f"Cloning source path {source} to {target}...")
        # An earlier run on a reused PVC may have left the target behind,
        # which would make copytree fail.
        if os.path.isdir(target) and not os.path.islink(target):
            shutil.rmtree(target)
        elif os.path.lexists(target):
            os.remove(target)
        target_parent = os.path.dirname(target)
        if target_parent:
            os.makedirs(target_parent, exist_ok=True)
        if os.path.isdir(source):
            shutil.copytree(source, target)
            logger.info("Cloning finished. Target path contents:")
            logger.info(os.listdir(target))
        else:
            shutil.copy(source, target)
            logger.info("Cloning finished.")

    def resolve_path(arg_key, kind):
        """Return (local path, path below train_mount) for an inputPath/outputPath placeholder."""
        if arg_key not in train_parameters:
            fail(f"Required parameter '{arg_key}' missing in component input!")
        path_key = train_parameters[arg_key]
        if path_key not in path_parameters:
            fail(f"{kind} '{path_key}' unavailable in training component!")
        local_path = path_parameters[path_key]
        return local_path, f"{train_mount}{local_path}"

    def resolve_args(spec_args):
        """Resolve component-spec argument placeholders into plain values.

        Side effects: copies input paths below train_mount and records
        (mounted path, local path) pairs of outputs in `outputs`.
        """
        resolved = []
        for arg in spec_args:
            if not isinstance(arg, dict):
                # literal argument, e.g. an option name such as "--epochs"
                resolved.append(arg)
            elif "inputValue" in arg:
                key = arg["inputValue"]
                if key not in train_parameters:
                    fail(f"Required parameter '{key}' missing in component input!")
                resolved.append(train_parameters[key])
            elif "if" in arg:
                # Optional parameter: emit the spec's own 'then' branch so the
                # option name matches what the training program parses (KFP
                # may spell it differently from the parameter key).
                condition = arg["if"]
                key = condition["cond"]["isPresent"]
                if key in train_parameters:
                    branch = condition.get("then", [f"--{key}", {"inputValue": key}])
                else:
                    branch = condition.get("else", [])
                if not isinstance(branch, list):
                    branch = [branch]
                resolved.extend(resolve_args(branch))
            elif "inputPath" in arg:
                local_path, mount = resolve_path(arg["inputPath"], "InputPath")
                clone_path(local_path, mount)
                resolved.append(mount)
            elif "outputPath" in arg:
                local_path, mount = resolve_path(arg["outputPath"], "OutputPath")
                outputs.append((mount, local_path))
                resolved.append(mount)
            else:
                fail(f"Unsupported argument placeholder '{arg}' in training component!")
        return resolved

    def has_condition(job, condition_type):
        """Return True if the Job reports condition_type with status "True"."""
        conditions = (job.status.conditions if job.status else None) or []
        return any(c.type == condition_type and c.status == "True" for c in conditions)

    def wait_for_job():
        """Block until the Job finished; return True on success, False on failure."""
        w = watch.Watch()
        while True:
            # The stream ends whenever the server closes the watch (e.g. on its
            # request timeout), so keep re-opening it until the Job is done.
            for event in w.stream(
                batch_api.list_namespaced_job,
                namespace=namespace,
                field_selector=f"metadata.name={job_name}",
                timeout_seconds=0
            ):
                job = event['object']
                if event['type'] == 'DELETED':
                    w.stop()
                    fail(f"Job '{namespace}.{job_name}' was deleted before it finished!")
                if job.status and job.status.succeeded:
                    w.stop()
                    return True
                if has_condition(job, "Failed"):
                    w.stop()
                    return False

    def read_job_logs(job_def):
        """Return the log of the Job's most recently created pod, or an error note."""
        selector_spec = job_def.spec.selector
        match_labels = selector_spec.match_labels if selector_spec else None
        if match_labels:
            label_selector = ",".join(f"{k}={v}" for k, v in sorted(match_labels.items()))
        else:
            label_selector = f"controller-uid={job_def.metadata.uid}"
        try:
            pods = core_api.list_namespaced_pod(
                namespace=namespace,
                label_selector=label_selector,
                timeout_seconds=10
            )
            if not pods.items:
                message = f"No pods found for Job '{job_name}'."
                logger.error(message)
                return message
            # With retries a Job can own several pods; the newest one ran last.
            pod = max(pods.items, key=lambda p: p.metadata.creation_timestamp)
            response = core_api.read_namespaced_pod_log(
                name=pod.metadata.name,
                namespace=namespace,
                _return_http_data_only=True,
                _preload_content=False
            )
            return response.data.decode("utf-8")
        except ApiException as e:
            logger.error(f"Error reading logs: {e}")
            return f"Error reading logs: {e}"

    def delete_resources(job_created, pvc_created):
        """Best-effort removal of the Job (with its pods) and of a PVC created here."""
        if job_created:
            try:
                # batch/v1 Jobs orphan their pods on API deletion by default;
                # background propagation removes the pods as well.
                batch_api.delete_namespaced_job(
                    job_name, namespace, propagation_policy="Background"
                )
            except ApiException as e:
                logger.error(f"Error deleting Job '{job_name}': {e}")
        if pvc_created:
            try:
                core_api.delete_namespaced_persistent_volume_claim(pvc_name, namespace)
            except ApiException as e:
                logger.error(f"Error deleting PVC '{pvc_name}': {e}")

    logger.info("Establishing cluster connection...")
    config.load_incluster_config()
    batch_api = client.BatchV1Api()
    core_api = client.CoreV1Api()

    # init configuration variables
    epoch = datetime.today().strftime('%Y%m%d%H%M%S')
    job_name = f"job-{model_name}-{epoch}"

    if namespace == "":
        with open(SA_NAMESPACE) as f:
            # strip guards against trailing whitespace/newline in the file
            namespace = f.read().strip()

    node_selector_map = {}
    if node_selector != "":
        node_selector_map = yaml.safe_load(node_selector)
        if not isinstance(node_selector_map, dict):
            fail(f"Invalid node selector '{node_selector}': expected a 'key: value' mapping!")

    train_model_comp_yaml = yaml.safe_load(train_specification)
    container_yaml = train_model_comp_yaml["implementation"]["container"]
    command = container_yaml["command"]
    args = container_yaml.get("args", [])

    # Local paths that inputPath/outputPath placeholders may refer to.
    path_parameters = {
        "dataset_directory": dataset_directory,
        "model_dir": model_dir
    }
    outputs = []  # (path below train_mount, local target path) pairs
    actual_args = resolve_args(args)

    # Container commands must consist of strings; non-string parameter values
    # are rendered as JSON (e.g. true, 0.1, [1, 2]).
    train_command = [
        value if isinstance(value, str) else json.dumps(value)
        for value in command + actual_args
    ]

    logger.info("=======================================")
    logger.info("Derived configurations")
    logger.info("=======================================")
    logger.info(f"job_name: {job_name}")
    logger.info(f"namespace: {namespace}")
    logger.info(f"actual_args: {actual_args}")
    logger.info(f"train_command: {json.dumps(train_command)}")
    logger.info("=======================================")

    # Manifests are built as dicts rather than f-string YAML so that user
    # supplied values cannot break the document structure.
    pvc_manifest = None
    if pvc_name == "":
        pvc_name = f"{job_name}-pvc"
        pvc_manifest = {
            "apiVersion": "v1",
            "kind": "PersistentVolumeClaim",
            "metadata": {"name": pvc_name, "namespace": namespace},
            "spec": {
                "accessModes": ["ReadWriteMany"],
                "resources": {"requests": {"storage": pvc_size}},
            },
        }

    pod_spec = {
        "containers": [{
            "name": "training-container",
            "image": base_image,
            "command": train_command,
            "volumeMounts": [{"mountPath": train_mount, "name": "training"}],
            "resources": {
                "limits": {
                    "cpu": cpus,
                    "memory": memory,
                    "nvidia.com/gpu": gpus,
                },
            },
        }],
        "volumes": [{
            "name": "training",
            "persistentVolumeClaim": {"claimName": pvc_name},
        }],
        "restartPolicy": "Never",
    }
    if node_selector_map:
        pod_spec["nodeSelector"] = node_selector_map

    job_manifest = {
        "apiVersion": "batch/v1",
        "kind": "Job",
        "metadata": {"name": job_name, "namespace": namespace},
        "spec": {
            "template": {
                "metadata": {"annotations": {"sidecar.istio.io/inject": "false"}},
                "spec": pod_spec,
            },
        },
    }

    api_client = client.ApiClient()
    job_created = False
    pvc_created = False
    try:
        if pvc_manifest is not None:
            utils.create_from_yaml(api_client, yaml_objects=[pvc_manifest])
            pvc_created = True

        logger.info(f"Starting Job '{namespace}.{job_name}'")
        utils.create_from_yaml(api_client, yaml_objects=[job_manifest])
        job_created = True

        logger.info("Reading job information...")
        job_def = batch_api.read_namespaced_job(
            name=job_name,
            namespace=namespace
        )
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"Job information: {job_def}")

        logger.info("Waiting for Job to succeed...")
        succeeded = wait_for_job()
        if succeeded:
            logger.info("Job finished.")
        else:
            logger.error("Job Failed!")

        logger.info("Reading logs...")
        pod_log = read_job_logs(job_def)
        if not succeeded:
            # Surface the training logs in this step before resources are removed.
            logger.error(pod_log)
            raise Exception("Job Failed")

        logger.info("Receiving outputs...")
        for (source, target) in outputs:
            clone_path(source, target)
    finally:
        logger.info("Deleting Job resources...")
        delete_resources(job_created, pvc_created)

    logger.info("Preparing outputs...")
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    output = namedtuple(
        'TrainModelJobOutputs',
        ['logs']
    )

    logger.info("Finished.")
    return output(pod_log)


# Package train_model_job as a reusable KFP component: the generated component
# specification is written to component.yaml and the step runs on BASE_IMAGE.
train_model_job_comp = create_component_from_func(
    train_model_job,
    base_image=BASE_IMAGE,
    output_component_file='component.yaml',
)