# RHAI Features Test - TrainJob Submission

This notebook submits a TrainJob with RHAI features enabled (progression tracking and, optionally, JIT checkpointing).
It uses a shared PVC as the HuggingFace cache across the distributed training nodes.
Assertions are handled in the Go test that runs this notebook.

In [None]:
def train_bloom():
    """Fine-tune a small causal language model in a distributed process group.

    Despite its name, this function loads the ``distilgpt2`` checkpoint. It
    detects the available accelerator (CUDA/ROCm, Intel XPU, or CPU), joins
    the ``torch.distributed`` process group with a matching backend, loads
    the first 100 rows of ``yahma/alpaca-cleaned``, and trains for 5 epochs
    with the Hugging Face ``Trainer``, evaluating once per epoch.

    The function takes no arguments and returns nothing; its observable
    output is the progress it prints. Every import is performed inside the
    body, so the function does not rely on names defined at notebook level
    (presumably because it is executed remotely on the training nodes --
    confirm against the SDK).
    """
    import os
    # Point the Hugging Face cache at the /workspace mount. This is set before
    # the Hugging Face libraries are imported below (presumably so that they
    # pick the location up at import time -- keep this ordering).
    os.environ["HF_HOME"] = "/workspace/hf_cache"
    
    import torch
    import torch.distributed as dist
    from datasets import load_dataset
    from transformers import (
        AutoTokenizer, AutoModelForCausalLM, Trainer,
        TrainingArguments, DataCollatorForLanguageModeling,
    )

    # ========== Auto-detect accelerator and configure ==========
    # Each branch selects the torch device, the process-group backend and the
    # mixed-precision flags that are later passed to TrainingArguments.
    if torch.cuda.is_available():
        # NVIDIA GPU or AMD ROCm (ROCm exposes as CUDA)
        device = torch.device("cuda")
        backend = "nccl"
        use_fp16 = True
        use_bf16 = False
        # Check for bf16 support (Ampere+ or ROCm); when available, bf16
        # replaces fp16 so that exactly one of the two flags is set.
        if torch.cuda.is_bf16_supported():
            use_fp16 = False
            use_bf16 = True
        accelerator = f"CUDA ({torch.cuda.get_device_name(0)})"
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        # Intel XPU
        device = torch.device("xpu")
        backend = "ccl"
        use_fp16 = True
        use_bf16 = False
        accelerator = "Intel XPU"
    else:
        # CPU fallback: no mixed precision
        device = torch.device("cpu")
        backend = "gloo"
        use_fp16 = False
        use_bf16 = False
        accelerator = "CPU"

    print(f"Detected accelerator: {accelerator}")
    print(f"Using backend: {backend}, fp16={use_fp16}, bf16={use_bf16}")

    # ========== Initialize distributed ==========
    # No rank, world size or rendezvous address is passed explicitly here
    # (presumably they come from environment variables set by the launcher --
    # confirm against the training runtime).
    dist.init_process_group(backend=backend)
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    # LOCAL_RANK falls back to 0 when the variable is not set.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    
    print(f"Distributed: WORLD_SIZE={world_size}, RANK={rank}, LOCAL_RANK={local_rank}")

    # Set device for this process (for multi-GPU)
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)

    # ========== Load model and tokenizer ==========
    model_name = "distilgpt2"
    
    # Rank 0 loads first while every other rank waits at the barrier; the
    # remaining ranks load afterwards (presumably hitting the cache that rank 0
    # populated on the shared volume instead of downloading concurrently).
    if rank == 0:
        print(f"Downloading model: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)
        print("Model downloaded")
    dist.barrier()
    
    if rank != 0:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)
    
    # Fall back to the EOS token for padding when the tokenizer defines no
    # pad token; tokenize_function below pads every example.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # ========== Load and prepare dataset ==========
    dataset_name = "yahma/alpaca-cleaned"
    
    # Same rank-0-first pattern as for the model: rank 0 loads, all ranks
    # synchronize, then the other ranks load.
    if rank == 0:
        print(f"Downloading dataset: {dataset_name}")
        dataset = load_dataset(dataset_name, split="train[:100]")
        print("Dataset downloaded")
    dist.barrier()
    
    if rank != 0:
        dataset = load_dataset(dataset_name, split="train[:100]")
    
    # 80/20 train/eval split with shuffling disabled.
    dataset = dataset.train_test_split(test_size=0.2, shuffle=False)
    train_ds = dataset["train"]
    eval_ds = dataset["test"]

    def tokenize_function(examples):
        """Render a batch as instruction/response prompts and tokenize it.

        Reads the "instruction" and "output" columns of ``examples`` and
        pads or truncates every rendered prompt to 128 tokens.
        """
        texts = [f"### Instruction:\n{i}\n\n### Response:\n{o}"
                 for i, o in zip(examples["instruction"], examples["output"])]
        return tokenizer(texts, padding="max_length", truncation=True, max_length=128)

    # The original columns are removed so only the tokenizer output remains.
    train_tokenized = train_ds.map(tokenize_function, batched=True, remove_columns=train_ds.column_names)
    eval_tokenized = eval_ds.map(tokenize_function, batched=True, remove_columns=eval_ds.column_names)

    # ========== Configure training arguments based on accelerator ==========
    # save_strategy="no": this configuration itself requests no periodic
    # checkpoints. Evaluation runs once per epoch; metrics are logged every
    # 5 steps and not reported to any external tracker.
    training_args = TrainingArguments(
        output_dir="/workspace/checkpoints",
        num_train_epochs=5,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=2,
        eval_strategy="epoch",
        save_strategy="no",
        logging_steps=5,
        report_to="none",
        # Auto-configured based on accelerator
        fp16=use_fp16,
        bf16=use_bf16,
        # Only enable gradient checkpointing on GPU (saves VRAM)
        gradient_checkpointing=torch.cuda.is_available(),
        # Disable pin_memory on CPU to avoid warning
        dataloader_pin_memory=torch.cuda.is_available(),
        # Fix for distributed checkpoint loading
        save_safetensors=True,  # Avoid pickle device mapping issues
        save_only_model=True,   # Skip optimizer state (avoids CPU device tag issue)
    )

    # mlm=False selects the causal (next-token) language-modeling collator
    # rather than masked language modeling.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_tokenized,
        eval_dataset=eval_tokenized,
        data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
    )
    
    print(f"Starting training on {accelerator}...")
    trainer.train()

    # Wait for every rank to finish training before announcing completion
    # (from rank 0 only) and tearing the process group down.
    dist.barrier()
    if rank == 0:
        print("Training is finished")
    dist.destroy_process_group()

In [None]:
import os
from kubernetes import client as k8s_client
from kubeflow.trainer import TrainerClient
from kubeflow.common.types import KubernetesBackendConfig

# Connection settings are supplied through the environment; each one has a
# fallback so the cell still runs when a variable is missing.
openshift_api_url = os.environ.get("OPENSHIFT_API_URL", "")
token = os.environ.get("NOTEBOOK_TOKEN", "")
namespace = os.environ.get("NOTEBOOK_NAMESPACE", "default")
shared_pvc_name = os.environ.get("SHARED_PVC_NAME", "shared-pvc")

print(f"API: {openshift_api_url}")
print(f"Namespace: {namespace}")
print(f"PVC: {shared_pvc_name}")

# Build the Kubernetes client configuration: bearer-token authentication
# against the API URL read above.
k8s_config = k8s_client.Configuration()
k8s_config.host = openshift_api_url
# NOTE: TLS certificate verification is switched off for this connection.
k8s_config.verify_ssl = False
k8s_config.api_key = {"authorization": f"Bearer {token}"}

api_client = k8s_client.ApiClient(k8s_config)

# The trainer backend is handed the configuration held by the API client.
backend_config = KubernetesBackendConfig(
    client_configuration=api_client.configuration,
)
trainer_client = TrainerClient(backend_config)

print("TrainerClient initialized")

In [None]:
# Name of the training runtime to use; TRAINING_RUNTIME overrides the default.
training_runtime_name = os.environ.get("TRAINING_RUNTIME", "torch-distributed")

# Fetch the runtime object through the trainer client and report its name.
torch_runtime = trainer_client.get_runtime(training_runtime_name)
print(f"Got runtime: {torch_runtime.name}")

In [None]:
from kubeflow.trainer.rhai.transformers import TransformersTrainer
from kubeflow.trainer.options import PodTemplateOverrides, PodTemplateOverride, PodSpecOverride, ContainerOverride
import os

def _env_flag(name, default):
    """Return True when env var *name* (or *default* if unset) is "true", ignoring case."""
    return os.getenv(name, default).lower() == "true"


# Feature toggles and checkpoint settings, all read from the environment.
enable_progression = _env_flag("ENABLE_PROGRESSION_TRACKING", "true")
enable_checkpoint = _env_flag("ENABLE_JIT_CHECKPOINT", "false")
checkpoint_output_dir = os.environ.get("CHECKPOINT_OUTPUT_DIR", "/workspace/checkpoints")
checkpoint_save_strategy = os.environ.get("CHECKPOINT_SAVE_STRATEGY", "epoch")
checkpoint_save_total_limit = int(os.environ.get("CHECKPOINT_SAVE_TOTAL_LIMIT", "3"))

print(f"Progression tracking: {enable_progression}")
print(f"JIT Checkpoint: {enable_checkpoint}")

# GPU resource name passed in from the Go test, e.g. "nvidia.com/gpu" or
# "amd.com/gpu"; an empty value means a CPU-only run.
gpu_resource_label = os.environ.get("GPU_RESOURCE_LABEL", "")

# Every node gets the same CPU/memory request; one GPU is added on top when
# a GPU resource name was supplied.
resources_per_node = {"cpu": 2, "memory": "8Gi"}
if not gpu_resource_label:
    print("CPU mode: no GPU requested")
else:
    resources_per_node[gpu_resource_label] = 1
    print(f"GPU mode: requesting {gpu_resource_label}: 1")

# Trainer configuration. Progression tracking is always passed explicitly
# because the SDK default has to be overridden in order to disable it.
trainer_kwargs = dict(
    func=train_bloom,
    func_args={},
    num_nodes=2,
    resources_per_node=resources_per_node,
    enable_progression_tracking=enable_progression,
)
if enable_progression:
    trainer_kwargs.update(metrics_port=28080, metrics_poll_interval_seconds=8)

# Checkpointing options are only added when JIT checkpointing is requested.
if enable_checkpoint:
    from kubeflow.trainer.rhai.transformers import PeriodicCheckpointConfig

    periodic_config = PeriodicCheckpointConfig(
        save_strategy=checkpoint_save_strategy,
        save_total_limit=checkpoint_save_total_limit,
    )
    trainer_kwargs.update(
        enable_jit_checkpoint=True,
        output_dir=checkpoint_output_dir,
        periodic_checkpoint_config=periodic_config,
    )

transformers_trainer = TransformersTrainer(**trainer_kwargs)

# Mount the shared PVC at /workspace in the "node" container of the "node" job.
workspace_volume = {
    "name": "workspace",
    "persistentVolumeClaim": {"claimName": shared_pvc_name},
}
workspace_mount = {"name": "workspace", "mountPath": "/workspace"}
node_container = ContainerOverride(name="node", volume_mounts=[workspace_mount])
node_pod_spec = PodSpecOverride(volumes=[workspace_volume], containers=[node_container])
node_override = PodTemplateOverride(target_jobs=["node"], spec=node_pod_spec)

job_name = trainer_client.train(
    trainer=transformers_trainer,
    runtime=torch_runtime,
    options=[PodTemplateOverrides(node_override)],
)
print(f"TRAINJOB_NAME: {job_name}")

In [None]:
# Wait for job completion.
#
# Phase 1: wait (up to 600 s) for the job to start. The terminal states are
# accepted here as well: a job that finishes or fails before it is ever
# observed as "Running" would otherwise make this call time out or raise,
# even though the next call treats "Failed" as an acceptable outcome.
trainer_client.wait_for_job_status(
    name=job_name,
    status={"Running", "Complete", "Failed"},
    timeout=600,
)

# Phase 2: wait (up to 1200 s) for a terminal state. This returns right away
# when phase 1 already observed one.
trainer_client.wait_for_job_status(
    name=job_name,
    status={"Complete", "Failed"},
    timeout=1200,
)

# Re-read the job so the printed status reflects its final state; the Go test
# performs the actual assertions.
job = trainer_client.get_job(name=job_name)
print(f"Training job final status: {job.status}")

In [None]:
# Print one summary line per step of the job: name, status and devices.
job_steps = trainer_client.get_job(name=job_name).steps
for step in job_steps:
    print(f"Step: {step.name}, Status: {step.status}, Devices: {step.device} x {step.device_count}")

In [None]:
# Dump the job logs line by line, without following the stream (follow=False).
job_log_lines = trainer_client.get_job_logs(job_name, follow=False)
for log_entry in job_log_lines:
    print(log_entry)

In [None]:
# Notebook completed - Go test handles assertions.
# This line is only reached when every cell above ran without raising; the
# printed marker is presumably what the Go test looks for -- keep it unchanged.
print("NOTEBOOK_STATUS: SUCCESS")