# Training the Fraud Detection model with the Kubeflow Training Operator

The example fraud detection model is small and trains quickly. Training many large models, however, requires multiple GPUs and often multiple machines. In this notebook, you learn how to scale out model training by using the Kubeflow Training Operator on OpenShift AI. You use the Training Operator SDK to create a PyTorchJob that runs the provided model training script.

### Install the Training Operator SDK

The Training Operator SDK is not installed by default in the TensorFlow workbench image. Run the following command to install it:

In [None]:
%pip install -qqU kubeflow-training==1.9.2

### Prepare the data

Typically, the training data for your model is available in a shared location. For this example, the data is local, so you upload it to your object storage to learn how to load data from a shared data source. The provided model training script downloads the training data, and the PyTorch `DistributedSampler` utility splits the data set among the worker processes so that each worker trains on a different subset.
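
To make the data split concrete, the following pure-Python sketch reproduces the index assignment that `torch.utils.data.DistributedSampler` performs with `shuffle=False` and `drop_last=False`: the index list is padded by reusing indices from the start until its length divides evenly by the number of replicas, and each rank then takes every `num_replicas`-th index, starting at its own rank. The function name `distributed_indices` is hypothetical, the sketch omits the seeded per-epoch shuffling that the sampler applies by default, and it is not code from the provided training script.

```python
import math


def distributed_indices(num_samples: int, num_replicas: int, rank: int) -> list[int]:
    """Mimic DistributedSampler(shuffle=False, drop_last=False) index assignment."""
    if not 0 <= rank < num_replicas:
        raise ValueError("rank must be in [0, num_replicas)")
    indices = list(range(num_samples))
    # Pad by reusing indices from the start so that every rank gets the same count.
    total_size = math.ceil(num_samples / num_replicas) * num_replicas
    padding = total_size - num_samples
    indices += (indices * math.ceil(padding / max(num_samples, 1)))[:padding]
    # Each rank takes every num_replicas-th index, starting at its own rank.
    return indices[rank:total_size:num_replicas]


# Two workers sharing five samples:
# distributed_indices(5, 2, 0) -> [0, 2, 4]
# distributed_indices(5, 2, 1) -> [1, 3, 0]
```

With 5 samples and 2 workers, sample 0 is assigned twice so that both ranks process the same number of samples per epoch and stay in step.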

In [None]:
import sys
sys.path.append('./utils')

import utils.s3

utils.s3.upload_directory_to_s3("data", "data")
print("---")
utils.s3.list_objects("data")
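
The `utils.s3` helpers belong to this project and are not shown in the notebook. As an assumption about how such an upload can work, the following sketch walks a local directory and uploads each file with an object key that mirrors its relative path. `upload_directory` is a hypothetical name, and `s3_client` is expected to provide an `upload_file(filename, bucket, key)` method, as a boto3 S3 client does.

```python
from pathlib import Path


def upload_directory(s3_client, bucket: str, local_dir: str, prefix: str) -> list[str]:
    """Upload every file under local_dir to s3://bucket/prefix/... and return the keys."""
    uploaded = []
    root = Path(local_dir)
    for path in sorted(p for p in root.rglob("*") if p.is_file()):
        # Use forward slashes in object keys regardless of the local OS.
        key = f"{prefix}/{path.relative_to(root).as_posix()}"
        s3_client.upload_file(str(path), bucket, key)
        uploaded.append(key)
    return uploaded


# Hypothetical usage with boto3 and the connection's environment variables:
# s3 = boto3.client(
#     "s3",
#     endpoint_url=os.environ["AWS_S3_ENDPOINT"],
#     aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"],
#     aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"],
# )
# upload_directory(s3, os.environ["AWS_S3_BUCKET"], "data", "data")
```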

### Authenticate the Training Operator SDK to the OpenShift cluster

The Training Operator SDK requires authenticated access to the OpenShift cluster so that it can create PyTorchJobs. The easiest way to get the access details is from the OpenShift web console.

1. To generate the command, select **Copy login command** from the username drop-down menu at the top right of the OpenShift web console.

    <figure>
        <img src="./assets/copy-login.png" alt="copy login">
    </figure>

2. Click **Display token**.

3. Below **Log in with this token**, take note of the parameters for token and server.
   For example:
    ```
    oc login --token=sha256~LongString --server=https://api.your-cluster.domain.com:6443
    ```    
    - token: `sha256~LongString`
    - server: `https://api.your-cluster.domain.com:6443`
    
4. In the following code cell, replace the token and server values with the values that you noted in Step 3.
   For example:
   ```
   api_server = "https://api.your-cluster.domain.com:6443"
   token = "sha256~LongString"
   ```
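
If you prefer to paste the whole login command, a small helper can extract the two values. This sketch assumes the command uses the `--token=...` and `--server=...` form shown in Step 3; `parse_oc_login` is a hypothetical helper, not part of `oc` or the Training Operator SDK. Treat the token as a secret and avoid saving it in the notebook.

```python
import shlex


def parse_oc_login(command: str) -> tuple[str, str]:
    """Return (server, token) from an `oc login --token=... --server=...` command."""
    values = {}
    for arg in shlex.split(command):
        for flag in ("--token=", "--server="):
            if arg.startswith(flag):
                # "--token=" -> "token", "--server=" -> "server"
                values[flag.strip("-=")] = arg[len(flag):]
    if "token" not in values or "server" not in values:
        raise ValueError("command must contain --token=... and --server=...")
    return values["server"], values["token"]


# api_server, token = parse_oc_login(
#     "oc login --token=sha256~LongString --server=https://api.your-cluster.domain.com:6443"
# )
```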


In [None]:
from kubernetes import client

api_server = "https://XXXX"
token = "sha256~XXXX"

configuration = client.Configuration()
configuration.host = api_server
configuration.api_key = {"authorization": f"Bearer {token}"}
# Uncomment the following line if your cluster API server uses a self-signed certificate
# or a certificate from an untrusted CA (this disables TLS certificate verification)
# configuration.verify_ssl = False
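
Before you create the client, an optional check can catch values that were left as placeholders. `check_credentials` is a hypothetical helper that only inspects the two strings; it does not contact the cluster, so a value that passes can still be rejected by the API server.

```python
def check_credentials(api_server: str, token: str) -> None:
    """Raise ValueError if the login values look unset or malformed."""
    # The notebook ships with "XXXX" placeholders in both values.
    if "XXXX" in api_server or "XXXX" in token:
        raise ValueError("Replace the XXXX placeholders with the values from 'oc login'")
    if not api_server.startswith("https://"):
        raise ValueError("api_server should be an https:// URL")
    if not token.strip():
        raise ValueError("token is empty")


# check_credentials(api_server, token)
```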

## Running the distributed training

### Initialize the Training client

Initialize the Training client by using the configuration that contains your cluster credentials.

In [None]:
from kubeflow.training import TrainingClient

client = TrainingClient(client_configuration=configuration)

### Create a PyTorchJob

Use the Training Operator SDK client to submit a PyTorchJob.

The model training script is imported from the `kfto-scripts` folder.

The model training script loads the training data set and distributes it among the nodes, performs distributed training, evaluates the model by using the test data set, exports the trained model to ONNX format, and uploads the ONNX model to the S3 bucket that is specified in the provided connection.
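
Inside the worker pods, the script can reach the bucket only through the `AWS_*` environment variables that the `create_job` call forwards with `env_vars`. In the notebook, `os.environ.get` returns `None` for a variable that the workbench connection did not set, so the pods would not receive a usable value. The following sketch, built around the hypothetical helper `read_s3_settings`, fails fast when any of the four variables is missing or empty; you could run it in the notebook before submitting the job. It is not taken from the provided training script.

```python
import os
from collections.abc import Mapping

# The four variables that the create_job call forwards to the worker pods.
S3_ENV_VARS = ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_S3_ENDPOINT", "AWS_S3_BUCKET")


def read_s3_settings(environ: Mapping[str, str] = os.environ) -> dict[str, str]:
    """Return the S3 connection settings, or raise if any are missing or empty."""
    missing = [name for name in S3_ENV_VARS if not environ.get(name)]
    if missing:
        raise RuntimeError(f"Missing S3 environment variables: {', '.join(missing)}")
    return {name: environ[name] for name in S3_ENV_VARS}
```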

*Important note:* If both of the following are true, you must edit the next Python cell to uncomment the `labels` argument in the `create_job` call:

* You are not using the Red Hat Sandbox test environment.

* The Kueue component is enabled for OpenShift AI, you have created all Kueue-related resources (`ResourceFlavor`, `ClusterQueue`, and `LocalQueue`), and you have set `local_queue_name` to "local-queue", as described in the _Setting up Kueue resources_ section of this Fraud Detection workshop/tutorial.
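
Instead of editing a commented-out line, you could derive the `labels` argument from `local_queue_name`. The following sketch uses a hypothetical helper, `kueue_labels`, and assumes that passing `labels=None` to `create_job` means "no extra labels", which matches the argument being optional but is not confirmed here.

```python
from typing import Dict, Optional

# Label that Kueue uses to assign a job to a LocalQueue.
KUEUE_QUEUE_LABEL = "kueue.x-k8s.io/queue-name"


def kueue_labels(local_queue_name: Optional[str]) -> Optional[Dict[str, str]]:
    """Return the Kueue queue-name label for the job, or None when Kueue is not used."""
    if not local_queue_name:
        return None
    return {KUEUE_QUEUE_LABEL: local_queue_name}
```

With this helper, you would pass `labels=kueue_labels(local_queue_name)` and set `local_queue_name = None` when Kueue is not in use.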

In [None]:
import sys
import os
sys.path.append("./kfto-scripts")  # needed to make training function available in the notebook
from train_pytorch_cpu import train_func
from kubernetes.client import (
    V1EnvVar,
    V1EnvVarSource,
    V1SecretKeySelector
)

# Job name serves as a unique identifier to retrieve job-related information by using the SDK
job_name = "fraud-detection"

# If the Kueue component is enabled, and you have created the Kueue-related resources (ResourceFlavor, ClusterQueue and LocalQueue), provide the LocalQueue name on the following line:
local_queue_name = "local-queue"

client.create_job(
    job_kind="PyTorchJob",
    name=job_name,
    train_func=train_func,
    num_workers=2,
    num_procs_per_worker="1",
    resources_per_worker={
        "memory": "4Gi",
        "cpu": 1,
    },
    base_image="quay.io/modh/training:py311-cuda124-torch251",
    # If the Kueue component is enabled and you have created the Kueue-related resources (ResourceFlavor, ClusterQueue and LocalQueue), uncomment the following line to add the queue-name label:
    # labels={"kueue.x-k8s.io/queue-name": local_queue_name},
    env_vars=[
        V1EnvVar(name="AWS_ACCESS_KEY_ID", value=os.environ.get("AWS_ACCESS_KEY_ID")),
        V1EnvVar(name="AWS_S3_BUCKET", value=os.environ.get("AWS_S3_BUCKET")),
        V1EnvVar(name="AWS_S3_ENDPOINT", value=os.environ.get("AWS_S3_ENDPOINT")),
        V1EnvVar(name="AWS_SECRET_ACCESS_KEY", value=os.environ.get("AWS_SECRET_ACCESS_KEY")),
    ],
    packages_to_install=[
        "s3fs",
        "boto3",
        "scikit-learn",
        "onnx",
    ],
)

### Query important job information

In [None]:
import time


# Wait until the job finishes. Check the terminal states first, then treat a job
# that is created (for example, still queued or starting) or running as in progress.
print(f"PyTorchJob '{job_name}' is running.", end='')
while True:
    try:
        if client.is_job_succeeded(name=job_name):
            print(".")
            print([x.message for x in client.get_job_conditions(name=job_name) if x.type == "Succeeded"][0])
            break
        elif client.is_job_failed(name=job_name):
            print(".")
            print([x.message for x in client.get_job_conditions(name=job_name) if x.type == "Failed"][0])
            break
        elif client.is_job_running(name=job_name) or client.is_job_created(name=job_name):
            print(".", end='')
        else:
            print(f"PyTorchJob '{job_name}' status not available or no conditions found.")
            break

    except Exception as e:
        print(f"Error getting PyTorchJob status: {e}.")

    time.sleep(3)
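
The status loop has no upper bound on how long it waits. A generic polling helper with a deadline is one way to add one. `wait_for` is a hypothetical name, and the commented usage assumes the `client` and `job_name` variables from the earlier cells.

```python
import time
from typing import Callable, Optional


def wait_for(check: Callable[[], Optional[str]], timeout_s: float = 1800, interval_s: float = 3) -> str:
    """Call check() until it returns a non-None result, or raise after timeout_s seconds."""
    deadline = time.monotonic() + timeout_s
    while True:
        result = check()
        if result is not None:
            return result
        if time.monotonic() >= deadline:
            raise TimeoutError(f"no result after {timeout_s} seconds")
        time.sleep(interval_s)


# Hypothetical usage with the Training client from the earlier cells:
# def job_outcome():
#     if client.is_job_succeeded(name=job_name):
#         return "Succeeded"
#     if client.is_job_failed(name=job_name):
#         return "Failed"
#     return None  # still queued, created, or running
#
# print(wait_for(job_outcome, timeout_s=30 * 60))
```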

In [None]:
# Get the logs of the master pod; its name is derived from the job name
print(client.get_job_logs(name=job_name)[0][f"{job_name}-master-0"])
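
The log dictionary is keyed by pod name. Assuming that the SDK turns `num_workers=2` into one master replica and one worker replica, and that the Training Operator names pods `<job name>-<replica type>-<index>`, you can compute the expected pod names instead of hard-coding them. `expected_pod_names` is a hypothetical helper.

```python
def expected_pod_names(job_name: str, num_workers: int) -> list[str]:
    """Pod names for a PyTorchJob with one master and num_workers - 1 workers."""
    if num_workers < 1:
        raise ValueError("num_workers must be at least 1")
    names = [f"{job_name}-master-0"]
    names += [f"{job_name}-worker-{i}" for i in range(num_workers - 1)]
    return names


# master_pod = expected_pod_names(job_name, 2)[0]
# print(client.get_job_logs(name=job_name)[0][master_pod])
```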

### Delete the job

After the PyTorchJob is finished, you can delete it.

In [None]:
client.delete_job(name=job_name)