# RHAI Features Test - TrainJob Submission

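This notebook submits a TrainJob that exercises RHAI features: progression tracking and, optionally, JIT checkpointing.
A shared PVC holds the HuggingFace cache and, in disconnected environments, the model and dataset pre-downloaded from S3, so every distributed training node reads the same files.
Assertions are handled by the Go test.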

In [None]:
def train_bloom():
    """Fine-tune distilgpt2 on a 100-sample slice of alpaca-cleaned with DDP.

    This function runs inside the training pods (it is handed to the trainer as
    ``func``), which is why every import it needs lives inside its body.

    Data source is auto-detected:
    - If the model exists on the shared PVC (pre-downloaded from S3 by the
      notebook for disconnected environments), the model and dataset are
      loaded from local paths; training pods need no S3 access.
    - Otherwise both are downloaded from the HuggingFace Hub into the shared
      HF cache. Rank 0 downloads first and the other ranks reuse the cache.

    Accelerator selection: CUDA/ROCm, then Intel XPU, then CPU. Setting the
    ``GPU_TYPE`` environment variable to ``cpu`` forces CPU training even when
    a GPU is visible.
    """
    import os

    # Point the HF cache at the shared PVC before importing the HF libraries
    # below, so they pick up this location.
    os.environ["HF_HOME"] = "/workspace/hf_cache"

    import torch
    import torch.distributed as dist
    from datasets import load_dataset, load_from_disk
    from transformers import (
        AutoTokenizer, AutoModelForCausalLM, Trainer,
        TrainingArguments, DataCollatorForLanguageModeling,
    )

    # Local paths on the shared PVC (pre-downloaded by the notebook in disconnected mode)
    model_local_path = "/workspace/models/distilgpt2"
    dataset_local_path = "/workspace/datasets/alpaca-cleaned"

    # HuggingFace model/dataset names (for connected environments)
    model_name = "distilgpt2"
    dataset_name = "yahma/alpaca-cleaned"

    # Local mode is selected by the presence of the model's config.json.
    use_local_data = os.path.exists(os.path.join(model_local_path, "config.json"))

    # ========== Auto-detect accelerator and configure ==========
    # GPU_TYPE env var overrides auto-detection (set by test framework)
    gpu_type = os.environ.get("GPU_TYPE", "").lower()
    force_cpu = gpu_type == "cpu"

    if force_cpu:
        print("CPU mode forced via GPU_TYPE=cpu")
        device = torch.device("cpu")
        backend = "gloo"
        use_fp16 = False
        use_bf16 = False
        accelerator = "CPU (forced)"
    elif torch.cuda.is_available() and torch.cuda.device_count() > 0:
        # NVIDIA GPU or AMD ROCm (ROCm exposes as CUDA)
        device = torch.device("cuda")
        backend = "nccl"
        # Prefer bf16 where supported (Ampere+ or ROCm), otherwise use fp16.
        use_bf16 = torch.cuda.is_bf16_supported()
        use_fp16 = not use_bf16
        accelerator = f"CUDA ({torch.cuda.get_device_name(0)})"
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        # Intel XPU
        device = torch.device("xpu")
        backend = "ccl"
        use_fp16 = True
        use_bf16 = False
        accelerator = "Intel XPU"
    else:
        # CPU fallback
        device = torch.device("cpu")
        backend = "gloo"
        use_fp16 = False
        use_bf16 = False
        accelerator = "CPU"

    # Gate CUDA-only features on the *selected* device, not on
    # torch.cuda.is_available(), so GPU_TYPE=cpu keeps CUDA out of play.
    on_cuda = device.type == "cuda"

    print(f"Detected accelerator: {accelerator}")
    print(f"Using backend: {backend}, fp16={use_fp16}, bf16={use_bf16}")

    # ========== Initialize distributed ==========
    dist.init_process_group(backend=backend)
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    local_rank = int(os.environ.get("LOCAL_RANK", 0))

    print(f"Distributed: WORLD_SIZE={world_size}, RANK={rank}, LOCAL_RANK={local_rank}")

    try:
        # Bind this process to the GPU matching its LOCAL_RANK (multi-GPU nodes).
        if on_cuda:
            torch.cuda.set_device(local_rank)

        # ========== Load model and dataset ==========
        if use_local_data:
            # Local mode: load from shared PVC (pre-downloaded by notebook)
            print("Local mode: loading from shared PVC")
            print(f"Loading model from: {model_local_path}")
            tokenizer = AutoTokenizer.from_pretrained(model_local_path)
            model = AutoModelForCausalLM.from_pretrained(model_local_path)

            print(f"Loading dataset from: {dataset_local_path}")
            dataset = load_from_disk(dataset_local_path)
            # A saved DatasetDict needs its train split; a plain Dataset is used as-is.
            if hasattr(dataset, "keys") and "train" in dataset.keys():
                dataset = dataset["train"]
            # Take only first 100 samples for testing
            dataset = dataset.select(range(min(100, len(dataset))))
        else:
            # HuggingFace mode: rank 0 downloads into the shared cache first,
            # then the other ranks load from it after the barrier.
            print(f"HuggingFace mode: model={model_name}, dataset={dataset_name}")

            if rank == 0:
                print(f"Downloading model from HuggingFace: {model_name}")
                tokenizer = AutoTokenizer.from_pretrained(model_name)
                model = AutoModelForCausalLM.from_pretrained(model_name)
                print("Model downloaded")
            dist.barrier()

            if rank != 0:
                tokenizer = AutoTokenizer.from_pretrained(model_name)
                model = AutoModelForCausalLM.from_pretrained(model_name)

            if rank == 0:
                print(f"Downloading dataset from HuggingFace: {dataset_name}")
                dataset = load_dataset(dataset_name, split="train[:100]")
                print("Dataset downloaded")
            dist.barrier()

            if rank != 0:
                dataset = load_dataset(dataset_name, split="train[:100]")

        # GPT-2 style tokenizers ship without a pad token; reuse EOS for padding.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Deterministic 80/20 split (no shuffle) so every rank sees the same split.
        dataset = dataset.train_test_split(test_size=0.2, shuffle=False)
        train_ds = dataset["train"]
        eval_ds = dataset["test"]

        def tokenize_function(examples):
            """Render instruction/output pairs as prompts and tokenize to 128 tokens."""
            texts = [f"### Instruction:\n{i}\n\n### Response:\n{o}"
                     for i, o in zip(examples["instruction"], examples["output"])]
            return tokenizer(texts, padding="max_length", truncation=True, max_length=128)

        train_tokenized = train_ds.map(tokenize_function, batched=True, remove_columns=train_ds.column_names)
        eval_tokenized = eval_ds.map(tokenize_function, batched=True, remove_columns=eval_ds.column_names)

        # ========== Configure training arguments based on accelerator ==========
        training_args = TrainingArguments(
            output_dir="/workspace/checkpoints",
            num_train_epochs=5,
            per_device_train_batch_size=2,
            gradient_accumulation_steps=2,
            eval_strategy="epoch",
            save_strategy="no",
            logging_steps=5,
            report_to="none",
            # Auto-configured based on accelerator
            fp16=use_fp16,
            bf16=use_bf16,
            # Honour GPU_TYPE=cpu even when a CUDA device is visible; without
            # this the Trainer would still pick the GPU on its own.
            use_cpu=force_cpu,
            # Only enable gradient checkpointing on GPU (saves VRAM)
            gradient_checkpointing=on_cuda,
            # Disable pin_memory on CPU to avoid warning
            dataloader_pin_memory=on_cuda,
            # Fix for distributed checkpoint loading
            save_safetensors=True,  # Avoid pickle device mapping issues
            save_only_model=True,   # Skip optimizer state (avoids CPU device tag issue)
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_tokenized,
            eval_dataset=eval_tokenized,
            data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
        )

        print(f"Starting training on {accelerator}...")
        trainer.train()

        dist.barrier()
        if rank == 0:
            print("Training is finished")
    finally:
        # Release the process group even if loading or training raised.
        dist.destroy_process_group()

In [None]:
import os
import warnings
import urllib3

# Disconnected clusters typically use self-signed certificates and TLS
# verification is disabled below, so mute the resulting warnings up front.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
warnings.filterwarnings("ignore", message=".*Unverified HTTPS.*")
warnings.filterwarnings("ignore", category=urllib3.exceptions.InsecureRequestWarning)

from kubernetes import client as k8s_client
from kubeflow.trainer import TrainerClient
from kubeflow.common.types import KubernetesBackendConfig

# Connection settings supplied through the environment.
openshift_api_url = os.getenv("OPENSHIFT_API_URL", "")
token = os.getenv("NOTEBOOK_TOKEN", "")
namespace = os.getenv("NOTEBOOK_NAMESPACE", "default")
shared_pvc_name = os.getenv("SHARED_PVC_NAME", "shared-pvc")

for setting_label, setting_value in (
    ("API", openshift_api_url),
    ("Namespace", namespace),
    ("PVC", shared_pvc_name),
):
    print(f"{setting_label}: {setting_value}")

# Bearer-token authenticated client; certificate checks are intentionally off.
cfg = k8s_client.Configuration()
cfg.host = openshift_api_url
cfg.verify_ssl = False
cfg.api_key = {"authorization": f"Bearer {token}"}

api_client = k8s_client.ApiClient(cfg)
backend_cfg = KubernetesBackendConfig(client_configuration=api_client.configuration)

trainer_client = TrainerClient(backend_cfg)
print("TrainerClient initialized")

In [None]:
# The training runtime to launch the TrainJob with is provided via the environment.
if not (training_runtime_name := os.getenv("TRAINING_RUNTIME")):
    raise RuntimeError("TRAINING_RUNTIME environment variable is required")

torch_runtime = trainer_client.get_runtime(training_runtime_name)
# Guard against a lookup that yields nothing instead of raising.
if torch_runtime is None:
    raise RuntimeError(f"Required runtime '{training_runtime_name}' not found")
print(f"Got runtime: {torch_runtime.name}")

In [None]:
# Pre-download model/dataset from S3 to shared PVC (disconnected environments only)
# Training pods will then load from these local paths
#
# NOTE: Shared PVC is mounted at different paths:
#   - Notebook pod: /opt/app-root/src (we write here)
#   - Training pods: /workspace (they read from here)
# Same PVC, different mount points - data is shared!

import os
import shutil

# Clean up any leftover checkpoints from previous runs to avoid resume conflicts
# (Notebook path = /opt/app-root/src, training pods see it as /workspace)
checkpoints_path = "/opt/app-root/src/checkpoints"
if os.path.exists(checkpoints_path):
    print(f"Cleaning up old checkpoints at {checkpoints_path}")
    shutil.rmtree(checkpoints_path)
    print("  ✅ Old checkpoints removed")

s3_endpoint = os.getenv("AWS_DEFAULT_ENDPOINT", "")
s3_access_key = os.getenv("AWS_ACCESS_KEY_ID", "")
s3_secret_key = os.getenv("AWS_SECRET_ACCESS_KEY", "")
s3_bucket = os.getenv("AWS_STORAGE_BUCKET", "")
model_s3_prefix = os.getenv("MODEL_S3_PREFIX", "models/distilgpt2")
dataset_s3_prefix = os.getenv("DATASET_S3_PREFIX", "alpaca-cleaned-datasets")

# Notebook writes to /opt/app-root/src (its PVC mount point)
# Training pods will read from /workspace (their PVC mount point)
# Same underlying storage!
notebook_pvc_path = "/opt/app-root/src"
model_local_path = f"{notebook_pvc_path}/models/distilgpt2"
dataset_local_path = f"{notebook_pvc_path}/datasets/alpaca-cleaned"

use_s3 = bool(s3_endpoint and s3_bucket and s3_access_key and s3_secret_key)


def s3_dir_prefix(prefix: str) -> str:
    """Return *prefix* as an S3 "directory" prefix ending in "/".

    S3 prefixes are plain string prefixes, so "models/distilgpt2" alone would
    also match "models/distilgpt2-large/...". An empty prefix (the whole
    bucket) stays "".
    """
    prefix = prefix.rstrip("/")
    return f"{prefix}/" if prefix else ""


def s3_key_to_relpath(key: str, prefix: str):
    """Map an S3 object *key* to a path relative to the *prefix* directory.

    Returns None for keys outside that directory and for "directory marker"
    objects (keys ending in "/"), which cannot be written to a file path.
    """
    dir_prefix = s3_dir_prefix(prefix)
    if key.endswith("/") or not key.startswith(dir_prefix):
        return None
    # lstrip guards against "prefix//file" keys: a leading "/" would make
    # os.path.join discard the local directory entirely.
    return key[len(dir_prefix):].lstrip("/") or None


def saved_dataset_present(path: str) -> bool:
    """Return True if *path* holds a dataset written by ``save_to_disk``.

    A saved DatasetDict contains ``dataset_dict.json``; a single saved Dataset
    contains ``state.json``. The training function loads either form.
    """
    return any(
        os.path.exists(os.path.join(path, marker))
        for marker in ("dataset_dict.json", "state.json")
    )


if use_s3:
    print("S3 mode: downloading to shared PVC for training pods")
    print(f"  Endpoint: {s3_endpoint}")
    print(f"  Bucket: {s3_bucket}")

    import boto3
    from botocore.config import Config
    from pathlib import Path

    config = Config(
        signature_version="s3v4",
        s3={"addressing_style": "path"},
    )

    endpoint_url = s3_endpoint if s3_endpoint.startswith("http") else f"https://{s3_endpoint}"
    s3_client = boto3.client(
        "s3",
        endpoint_url=endpoint_url,
        aws_access_key_id=s3_access_key,
        aws_secret_access_key=s3_secret_key,
        config=config,
        verify=False,  # self-signed certs in disconnected environments
    )

    def download_from_s3(s3_prefix: str, local_path: str):
        """Download every object under the *s3_prefix* directory into *local_path*.

        Raises RuntimeError when no objects are found. A wrong prefix then fails
        here, instead of leaving the training pods without local data (they
        would fall back to the HuggingFace Hub).
        """
        print(f"  Downloading s3://{s3_bucket}/{s3_prefix}/ -> {local_path}")
        Path(local_path).mkdir(parents=True, exist_ok=True)

        paginator = s3_client.get_paginator("list_objects_v2")
        count = 0
        for page in paginator.paginate(Bucket=s3_bucket, Prefix=s3_dir_prefix(s3_prefix)):
            for obj in page.get("Contents", []):
                key = obj["Key"]
                rel_path = s3_key_to_relpath(key, s3_prefix)
                if rel_path is None:
                    continue
                local_file = os.path.join(local_path, rel_path)
                os.makedirs(os.path.dirname(local_file), exist_ok=True)
                s3_client.download_file(s3_bucket, key, local_file)
                count += 1
        if count == 0:
            raise RuntimeError(
                f"No objects found under s3://{s3_bucket}/{s3_dir_prefix(s3_prefix)}"
            )
        print(f"  ✅ Downloaded {count} files to {local_path}")

    # Download model if not already present
    if not os.path.exists(os.path.join(model_local_path, "config.json")):
        download_from_s3(model_s3_prefix, model_local_path)
    else:
        print(f"  Model already exists at {model_local_path}, skipping")

    # Download dataset if not already present (DatasetDict or plain Dataset)
    if not saved_dataset_present(dataset_local_path):
        download_from_s3(dataset_s3_prefix, dataset_local_path)
    else:
        print(f"  Dataset already exists at {dataset_local_path}, skipping")

    print("✅ S3 download complete - training pods will load from shared PVC")
else:
    print("HuggingFace mode: training pods will download directly from HF Hub")


In [None]:
from kubeflow.trainer.rhai.transformers import TransformersTrainer
from kubeflow.trainer.options import PodTemplateOverrides, PodTemplateOverride, PodSpecOverride, ContainerOverride
import os

# Read feature flags from environment
enable_progression = os.getenv("ENABLE_PROGRESSION_TRACKING", "true").lower() == "true"
enable_checkpoint = os.getenv("ENABLE_JIT_CHECKPOINT", "false").lower() == "true"
checkpoint_output_dir = os.getenv("CHECKPOINT_OUTPUT_DIR", "/workspace/checkpoints")
checkpoint_save_strategy = os.getenv("CHECKPOINT_SAVE_STRATEGY", "epoch")
checkpoint_save_total_limit = int(os.getenv("CHECKPOINT_SAVE_TOTAL_LIMIT", "3"))

print(f"Progression tracking: {enable_progression}")
print(f"JIT Checkpoint: {enable_checkpoint}")

# Read GPU and multi-node config from environment (passed from Go test)
gpu_resource_label = os.environ.get("GPU_RESOURCE_LABEL", "")
# GPU_TYPE is forwarded to the training pods. There, "cpu" forces CPU mode in
# train_bloom() and "" lets it auto-detect the accelerator. When GPU_TYPE is
# unset, default to "cpu" only if no GPU is requested; defaulting to "cpu"
# while requesting GPUs would leave the allocated GPUs unused.
gpu_type = os.environ.get("GPU_TYPE", "" if gpu_resource_label else "cpu")
num_nodes = int(os.environ.get("NUM_NODES", "2"))
num_gpus_per_node = int(os.environ.get("NUM_GPUS_PER_NODE", "1"))

print(f"Training config: num_nodes={num_nodes}, num_gpus_per_node={num_gpus_per_node}, gpu_type={gpu_type}")

# Configure resources - requesting the GPU resource makes k8s schedule on a GPU node
resources_per_node = {"cpu": 2, "memory": "8Gi"}
if gpu_resource_label:
    resources_per_node[gpu_resource_label] = num_gpus_per_node  # e.g., "nvidia.com/gpu": 2
    print(f"GPU mode: requesting {gpu_resource_label}: {num_gpus_per_node}")
else:
    print("CPU mode: no GPU requested")

# Env vars for the training pods (GPU_TYPE=cpu makes train_bloom() force CPU mode)
training_env = {"GPU_TYPE": gpu_type}

# Build trainer config
trainer_kwargs = {
    "func": train_bloom,
    "num_nodes": num_nodes,
    "resources_per_node": resources_per_node,
    "env": training_env,
    # Always set explicitly: False is needed to disable the SDK default.
    "enable_progression_tracking": enable_progression,
}
if enable_progression:
    trainer_kwargs["metrics_port"] = 28080
    trainer_kwargs["metrics_poll_interval_seconds"] = 8

# Add checkpointing if enabled
if enable_checkpoint:
    from kubeflow.trainer.rhai.transformers import PeriodicCheckpointConfig
    trainer_kwargs["enable_jit_checkpoint"] = True
    trainer_kwargs["output_dir"] = checkpoint_output_dir
    trainer_kwargs["periodic_checkpoint_config"] = PeriodicCheckpointConfig(
        save_strategy=checkpoint_save_strategy,
        save_total_limit=checkpoint_save_total_limit
    )

# Mount the shared PVC at /workspace in every "node" container; train_bloom()
# reads its HF cache, local model/dataset and checkpoint paths from there.
job_name = trainer_client.train(
    trainer=TransformersTrainer(**trainer_kwargs),
    runtime=torch_runtime,
    options=[
        PodTemplateOverrides(
            PodTemplateOverride(
                target_jobs=["node"],
                spec=PodSpecOverride(
                    volumes=[{"name": "workspace", "persistentVolumeClaim": {"claimName": shared_pvc_name}}],
                    containers=[ContainerOverride(
                        name="node",
                        volume_mounts=[{"name": "workspace", "mountPath": "/workspace"}]
                    )]
                )
            )
        )
    ]
)
print(f"TRAINJOB_NAME: {job_name}")

In [None]:
# Wait for job completion.
# Step 1: wait until the job has started. "Complete" is accepted too, so a job
# that finishes before any status poll observes "Running" does not make this
# wait time out.
trainer_client.wait_for_job_status(name=job_name, status={"Running", "Complete"}, timeout=600)
# Step 2: wait for a terminal state. "Failed" is included so this cell can
# finish on a failed job and the later cells can still print steps and logs.
trainer_client.wait_for_job_status(name=job_name, status={"Complete", "Failed"}, timeout=1200)

job = trainer_client.get_job(name=job_name)
print(f"Training job final status: {job.status}")

In [None]:
# Summarise every step of the TrainJob with its status and device allocation.
job_steps = trainer_client.get_job(name=job_name).steps
for job_step in job_steps:
    print(
        f"Step: {job_step.name}, Status: {job_step.status}, "
        f"Devices: {job_step.device} x {job_step.device_count}"
    )

In [None]:
# Dump the TrainJob's logs once, without following/streaming them.
job_log_lines = trainer_client.get_job_logs(job_name, follow=False)
for line in job_log_lines:
    print(line)

In [None]:
# Reaching this cell means every previous cell ran; assertions live in the Go
# test (presumably it looks for this marker in the notebook output).
final_status_marker = "NOTEBOOK_STATUS: SUCCESS"
print(final_status_marker)