## TRL (Transformer Reinforcement Learning) Training with Kubeflow SDK and Advanced Checkpointing

This notebook demonstrates how to use the Kubeflow Trainer SDK to create and manage TrainJobs.

### Features Demonstrated
- **Kubeflow SDK Integration**: Programmatic TrainJob creation and management
- **Checkpointing**: Controller-managed suspend/resume support backed by model checkpoints
- **TRL SFTTrainer**: Supervised fine-tuning of GPT-2 with PEFT LoRA on the Alpaca dataset for instruction following
- **Distributed Training**: Multi-node, multi-GPU coordination
- **Compute resource prerequisites for this demo**:
    This demo can run on:
    - CPU-based training using the Gloo backend (default configuration)
    - GPU-based (NVIDIA/AMD) training using the NCCL backend
        - Respective training images (update in [torch-cuda-custom](./cluster_training_runtime.yaml)):
            - quay.io/modh/training:py311-cuda124-torch251
            - quay.io/modh/training:py311-rocm62-torch251
    - Multi-node, multi-GPU distributed training using the Trainer v2 MLPolicy settings (`numNodes`/`numProcPerNode`)
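The backend choice and process layout above reduce to a little arithmetic. The sketch below is illustrative only: it assumes the `RANK`, `LOCAL_RANK` and `WORLD_SIZE` variables that `torchrun` exports to every worker, and the helper names (`choose_backend`, `world_size`, `global_rank`, `rank_from_env`) are hypothetical rather than taken from `trl_training.py`. For example, two nodes with four processes each give a world size of 8.

```python
import os
from typing import Mapping, Optional, Tuple


def choose_backend(gpu_available: bool) -> str:
    """Gloo for CPU-only workers, NCCL when GPUs are present.

    On ROCm builds of PyTorch the "nccl" backend name is backed by RCCL.
    """
    return "nccl" if gpu_available else "gloo"


def world_size(num_nodes: int, nproc_per_node: int) -> int:
    """Total number of training processes across all nodes."""
    return num_nodes * nproc_per_node


def global_rank(node_rank: int, nproc_per_node: int, local_rank: int) -> int:
    """Position of one process in the global process group."""
    return node_rank * nproc_per_node + local_rank


def rank_from_env(env: Optional[Mapping[str, str]] = None) -> Tuple[int, int, int]:
    """Read (RANK, LOCAL_RANK, WORLD_SIZE); defaults describe a single local process."""
    env = os.environ if env is None else env
    return (
        int(env.get("RANK", "0")),
        int(env.get("LOCAL_RANK", "0")),
        int(env.get("WORLD_SIZE", "1")),
    )


# Inside a worker one would then call, for example:
#   torch.distributed.init_process_group(backend=choose_backend(torch.cuda.is_available()))
print(choose_backend(False), world_size(2, 4), global_rank(1, 4, 2))  # prints: gloo 8 6
```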

### Prerequisites
- Persistent volume storage with RWX (ReadWriteMany) access mode: [workspace](workspace-checkpoint-storage)
- ClusterTrainingRuntime: [torch-cuda-custom](./cluster_training_runtime.yaml)
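RWX matters here because all training nodes share `/workspace` (the checkpoint directory and Hugging Face caches in the configuration below). A minimal sketch of such a claim, built as a Python dict so it can be written out as JSON; the claim name `workspace`, the `10Gi` size and the optional storage class are assumptions, so match them to the volume your `torch-cuda-custom` runtime actually mounts.

```python
import json
from typing import Optional


def rwx_pvc_manifest(name: str, size: str, storage_class: Optional[str] = None) -> dict:
    """Build a PersistentVolumeClaim that every training pod can mount read-write."""
    spec = {
        "accessModes": ["ReadWriteMany"],  # RWX: shared by all nodes of the TrainJob
        "resources": {"requests": {"storage": size}},
    }
    if storage_class:
        # Must be an RWX-capable storage class (for example NFS or CephFS based).
        spec["storageClassName"] = storage_class
    return {
        "apiVersion": "v1",
        "kind": "PersistentVolumeClaim",
        "metadata": {"name": name},
        "spec": spec,
    }


# Hypothetical name and size; `oc apply -f pvc.json` accepts JSON as well as YAML.
print(json.dumps(rwx_pvc_manifest("workspace", "10Gi"), indent=2))
```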

### Sample scripts
- [mnist.py](./scripts/mnist.py)
- [trl_training.py](./scripts/trl_training.py)
- `oc apply -k examples/kft-v2/manifests`

### References
- [Kubeflow Trainer SDK](https://github.com/kubeflow/sdk)
- [TRL Documentation](https://huggingface.co/docs/trl/)
- [PEFT Documentation](https://huggingface.co/docs/peft/)

In [None]:
# Install the Kubeflow SDK from PyPI
%pip install kubeflow
# Alternatively, install from the GitHub main branch:
# %pip install git+https://github.com/kubeflow/sdk.git@main

In [2]:
%pip show kubeflow

Name: kubeflow
Version: 0.1.0
Summary: Kubeflow Python SDK to manage ML workloads and to interact with Kubeflow APIs.
Home-page: https://github.com/kubeflow/sdk
Author: 
Author-email: The Kubeflow Authors <kubeflow-discuss@googlegroups.com>
License: 
Location: /opt/app-root/lib64/python3.12/site-packages
Requires: kubeflow-trainer-api, kubernetes, pydantic
Required-by: 
Note: you may need to restart the kernel to use updated packages.


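### Define TRL Training Function
The training function `trl_train` is defined in [trl_training.py](./scripts/trl_training.py). It provides:
- A progress file writer (implemented as training callbacks)
- Distributed checkpoint coordination
- Automatic model checkpointing triggered by SIGTERM signal handling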
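The real implementation lives in `trl_training.py`; as an illustration only, the progress writer and SIGTERM handling can be sketched with the standard library. Kubernetes sends SIGTERM to a container before stopping it (for example when a TrainJob is suspended or a pod is preempted) and waits for the termination grace period before killing it, so the loop below records the signal, finishes the current step, saves a checkpoint and returns. The class names, JSON field names and the `save_checkpoint` callback are hypothetical; in the actual script these hooks are more likely a Hugging Face `TrainerCallback`, and in a multi-node job only rank 0 should write the progress file.

```python
import json
import os
import signal
import tempfile
import threading
from typing import Callable


class ProgressFileWriter:
    """Write training progress as JSON, replacing the file atomically."""

    def __init__(self, path: str):
        self.path = path

    def write(self, step: int, total_steps: int, epoch: float, **metrics) -> dict:
        payload = {
            "current_step": step,
            "total_steps": total_steps,
            "epoch": epoch,
            "percentage_complete": round(100.0 * step / total_steps, 2) if total_steps else 0.0,
            **metrics,
        }
        directory = os.path.dirname(self.path) or "."
        os.makedirs(directory, exist_ok=True)
        # Write a temp file in the same directory, then rename it over the target,
        # so a reader never sees a half-written file.
        fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
        with os.fdopen(fd, "w") as f:
            json.dump(payload, f)
        os.replace(tmp_path, self.path)
        return payload


class SigtermFlag:
    """Record SIGTERM instead of letting the process die mid-step."""

    def __init__(self):
        self.requested = threading.Event()
        # signal.signal() must be called from the main thread.
        self._previous = signal.signal(signal.SIGTERM, self._handle)

    def _handle(self, signum, frame):
        self.requested.set()

    def restore(self) -> None:
        signal.signal(signal.SIGTERM, self._previous)


def train_loop(total_steps: int, progress: ProgressFileWriter, stop: SigtermFlag,
               save_checkpoint: Callable[[int], None]) -> int:
    """Toy loop that returns the last completed step."""
    for step in range(1, total_steps + 1):
        # ... forward / backward / optimizer.step() would run here ...
        progress.write(step, total_steps, epoch=0)
        if stop.requested.is_set():
            save_checkpoint(step)  # persist state before the grace period ends
            return step
    save_checkpoint(total_steps)
    return total_steps
```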

### Create TrainJob Using Kubeflow SDK
Now we'll use the Kubeflow SDK to configure the TrainJob:
- Training arguments (passed to the job as environment variables)
- *CustomTrainer* with the TRL training function
- *Initializer* for the dataset and model (Trainer v2 initializers)

In [3]:
from kubeflow.trainer import CustomTrainer, Initializer

training_env_args = {
    "PYTHONUNBUFFERED": "1",
    "PYTHONPATH": "/tmp/lib:$PYTHONPATH",

    # Training hyperparameters
    "LEARNING_RATE": "5e-5",
    "BATCH_SIZE": "1",
    "MAX_EPOCHS": "3",
    "WARMUP_STEPS": "5",
    "EVAL_STEPS": "3",
    "SAVE_STEPS": "2",
    "LOGGING_STEPS": "2",
    "GRADIENT_ACCUMULATION_STEPS": "2",
    
    # Model configuration
    "MODEL_NAME": "gpt2",
    "LORA_R": "16",
    "LORA_ALPHA": "32",
    "LORA_DROPOUT": "0.1",
    "MAX_SEQ_LENGTH": "512",
    
    # Dataset configuration
    "DATASET_NAME": "tatsu-lab/alpaca",
    "DATASET_TRAIN_SPLIT": "train[:500]",
    "DATASET_TEST_SPLIT": "train[500:520]",
    
    # Checkpointing configuration
    "CHECKPOINT_URI": "/workspace/checkpoints",
    "TRAINJOB_PROGRESSION_FILE_PATH": "/tmp/training_progression.json",
    
    # Cache directories
    "TRANSFORMERS_CACHE": "/workspace/cache/transformers",
    "HF_HOME": "/workspace/cache",
    "HF_DATASETS_CACHE": "/workspace/cache/datasets",
    
    # Distributed training debug
    "NCCL_DEBUG": "INFO",
    "NCCL_DEBUG_SUBSYS": "ALL",
    "NCCL_SOCKET_IFNAME": "eth0",
    "NCCL_IB_DISABLE": "1",
    "NCCL_P2P_DISABLE": "1",
    "NCCL_TREE_THRESHOLD": "0",
    "TORCH_DISTRIBUTED_DEBUG": "INFO",
    "TORCH_SHOW_CPP_STACKTRACES": "1",
}

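# trl_training.py lives in ./scripts; add that directory to the import path
import sys
sys.path.insert(0, "./scripts")
from trl_training import trl_train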

# Create CustomTrainer configuration
custom_trainer = CustomTrainer(
    func=trl_train,
    num_nodes=2,  # Distributed training across 2 nodes
    resources_per_node={
        "cpu": "2",
        "memory": "4Gi",
        # Uncomment for GPU training:
        # "nvidia.com/gpu": "1",
    },
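    packages_to_install=[
        "transformers[torch]",
        "trl",
        "peft",
        "datasets",
        "accelerate",
        "torch",
        # No commas below: Python joins these literals into the single entry
        # "numpy --target=/tmp/lib --verbose", which appends pip flags to the
        # install command (relies on the SDK passing entries to pip space-separated).
        # /tmp/lib is on PYTHONPATH (see training_env_args).
        "numpy"
        " --target=/tmp/lib"
        " --verbose"
    ],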
    env=training_env_args
)
from kubeflow.trainer.types import types

# Configure Initializers
initializer = Initializer(
    dataset=types.HuggingFaceDatasetInitializer(
        storage_uri="hf://tatsu-lab/alpaca"
    ),
    model=types.HuggingFaceModelInitializer(
        storage_uri="hf://gpt2"
    )
)

print("Training configuration initialised!")

Training configuration initialised!
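Every value in `training_env_args` is a string, so the training function has to convert them before use. A hedged sketch of how `trl_train` might parse them and pick a checkpoint to resume from; it assumes the Hugging Face `Trainer` naming convention `checkpoint-<step>` for checkpoint directories, the helper names are hypothetical, and the defaults simply mirror the cell above. Transformers ships a comparable helper, `transformers.trainer_utils.get_last_checkpoint`.

```python
import os
import re
from typing import Mapping, Optional


def read_hparams(env: Optional[Mapping[str, str]] = None) -> dict:
    """Convert the string-valued TrainJob env vars into typed hyperparameters."""
    env = os.environ if env is None else env
    return {
        "learning_rate": float(env.get("LEARNING_RATE", "5e-5")),
        "batch_size": int(env.get("BATCH_SIZE", "1")),
        "max_epochs": int(env.get("MAX_EPOCHS", "3")),
        "save_steps": int(env.get("SAVE_STEPS", "2")),
        "gradient_accumulation_steps": int(env.get("GRADIENT_ACCUMULATION_STEPS", "2")),
        "lora_r": int(env.get("LORA_R", "16")),
        "lora_alpha": int(env.get("LORA_ALPHA", "32")),
        "lora_dropout": float(env.get("LORA_DROPOUT", "0.1")),
        "checkpoint_dir": env.get("CHECKPOINT_URI", "/workspace/checkpoints"),
    }


_CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")


def latest_checkpoint(checkpoint_dir: str) -> Optional[str]:
    """Return the highest-numbered checkpoint-<step> subdirectory, or None to start fresh."""
    if not os.path.isdir(checkpoint_dir):
        return None
    found = []
    for name in os.listdir(checkpoint_dir):
        match = _CHECKPOINT_RE.match(name)
        if match and os.path.isdir(os.path.join(checkpoint_dir, name)):
            found.append((int(match.group(1)), name))  # numeric sort: 10 > 4
    if not found:
        return None
    return os.path.join(checkpoint_dir, max(found)[1])
```

The result can be passed straight to `Trainer.train(resume_from_checkpoint=...)`, where `None` means training starts from scratch.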


### Initialize Trainer Client
Use token authentication to initialize a `TrainerClient` and list the available training runtimes.

In [12]:
from kubeflow.trainer import TrainerClient
from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig
from kubernetes import client

api_server = "<api-server-url>"
token = "<auth-token>"

configuration = client.Configuration()
configuration.host = api_server
configuration.api_key = {"authorization": f"Bearer {token}"}

# Uncomment if your cluster API server uses a self-signed certificate or an untrusted CA
# configuration.verify_ssl = False

api_client = client.ApiClient(configuration)
trainer_client = TrainerClient(
    backend_config=KubernetesBackendConfig(client_configuration=api_client.configuration)
)

# Query the runtimes once and reuse the result
runtimes = trainer_client.list_runtimes()
print("Available runtimes :", len(runtimes))
for r in runtimes:
    print(f"- {r.name}")

Available runtimes : 5
- torch-cuda-241
- torch-cuda-251
- torch-cuda-custom
- torch-rocm-241
- torch-rocm-251


### Create TrainJob
Create a TrainJob using the resources declared above:
- Custom trainer
- Dataset and model initializers

In [7]:
job_name = trainer_client.train(
    trainer=custom_trainer,
    initializer=initializer,
    runtime=trainer_client.get_runtime("torch-cuda-custom")
)
print("Trainjob submitted!!")

Trainjob submitted!!


![TrainJobs and JobSets](./docs/trainjobs_jobsets.png)


![jobs](./docs/jobs.png)

### Start Monitoring: View Training Logs

![trainjob_pods](./docs/trainjob_pods.png)

In [None]:
# Get training logs
try:
    # Get logs from the training nodes
    logs = trainer_client.get_job_logs(job_name, follow=False)
    
    print("\n" + "="*80)
    print("TRAINING LOGS")
    print("="*80)
    
    # Display logs - logs is a generator, not a dict
    for log_line in logs:
        if log_line.strip():
            print(log_line)
    
    print("\n" + "="*80)
    
except Exception as e:
    print(f"Error getting logs: {e}")
    print("Note: Logs may not be available yet if training is still starting up")
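With `follow=False`, `get_job_logs` returns only what has been logged so far. To block until the job finishes, one option is to poll `trainer_client.get_job(job_name).status`. The helper below is a hypothetical sketch, not an SDK method; it assumes the terminal status strings are `Complete` (as in the cleanup output below) and `Failed`, and it takes injectable `sleep`/`clock` functions so it can be tested without waiting.

```python
import time
from typing import Callable, Iterable


def wait_for_status(get_status: Callable[[], str],
                    terminal: Iterable[str] = ("Complete", "Failed"),
                    timeout_s: float = 3600.0,
                    poll_s: float = 30.0,
                    sleep: Callable[[float], None] = time.sleep,
                    clock: Callable[[], float] = time.monotonic) -> str:
    """Poll get_status() until it returns a terminal value or the timeout expires."""
    terminal = set(terminal)
    deadline = clock() + timeout_s
    while True:
        status = get_status()
        if status in terminal:
            return status
        if clock() >= deadline:
            raise TimeoutError(f"TrainJob still {status!r} after {timeout_s}s")
        sleep(poll_s)


# Usage against the live cluster:
#   final_status = wait_for_status(lambda: trainer_client.get_job(job_name).status)
```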

### Cleanup resources

In [13]:
# Clean up the TrainJob when done
def cleanup_trainjob():
    """Clean up the TrainJob using Kubeflow SDK"""
    try:
        trainer_client.delete_job(job_name)
        print(f"TrainJob '{job_name}' deleted successfully")
    except Exception as e:
        print(f"Error deleting TrainJob: {e}")

# Get final job status before cleanup
try:
    final_job = trainer_client.get_job(job_name)
    print("Final TrainJob Status:")
    print(f"   Name: {final_job.name}")
    print(f"   Status: {final_job.status}")
    print(f"   Created: {final_job.creation_timestamp}")
    print(f"   Nodes: {final_job.num_nodes}")
    print(f"   Runtime: {final_job.runtime.name}")

    if final_job.steps:
        print("   Steps:")
        for step in final_job.steps:
            print(f"     - {step.name}: {step.status}")
    print()
except Exception as e:
    print(f"Error getting final job status: {e}")

# Delete the TrainJob whether or not it reported any steps
cleanup_trainjob()

Final TrainJob Status:
   Name: d5648a3bd444
   Status: Complete
   Created: 2025-10-06 15:08:34+00:00
   Nodes: 2
   Runtime: torch-cuda-custom
   Steps:
     - dataset-initializer: Succeeded
     - model-initializer: Succeeded
     - node-0: Succeeded
     - node-1: Succeeded

TrainJob 'd5648a3bd444' deleted successfully


