In [1]:
!pip install --upgrade pip
# !pip install --no-deps --upgrade --force-reinstall --no-cache-dir "git+https://github.com/briangallagher/sdk@training-hub"
!pip install --upgrade --force-reinstall --no-cache-dir "git+https://github.com/briangallagher/sdk@training-hub"
!pip install datasets

Collecting pip
  Downloading pip-25.2-py3-none-any.whl.metadata (4.7 kB)
Downloading pip-25.2-py3-none-any.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m62.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 24.2
    Uninstalling pip-24.2:
      Successfully uninstalled pip-24.2
Successfully installed pip-25.2
Collecting git+https://github.com/briangallagher/sdk@training-hub
  Cloning https://github.com/briangallagher/sdk (to revision training-hub) to /tmp/pip-req-build-2qe5rn9z
  Running command git clone --filter=blob:none --quiet https://github.com/briangallagher/sdk /tmp/pip-req-build-2qe5rn9z
  Running command git checkout -b training-hub --track origin/training-hub
  Switched to a new branch 'training-hub'
  branch 'training-hub' set up to track 'origin/training-hub'.
  Resolved https://github.com/briangallagher/sdk to commit 96ba7493de9d1678f7640

In [2]:
import os
import json
from datasets import load_dataset

# Prepare workspace directories
base_dir = "/opt/app-root/src"
data_dir = os.path.join(base_dir, "data")
outputs_dir = os.path.join(base_dir, "outputs")
ckpt_dir = os.path.join(base_dir, "checkpoints")

for d in [data_dir, outputs_dir, ckpt_dir]:
    os.makedirs(d, exist_ok=True)
    os.chmod(d, 0o777)
    print(f"[PY] Ensured directory exists and writable: {d}")

dataset_path = os.path.join(data_dir, "dataset.jsonl")

# Prepare workspace directories
ds = load_dataset("Open-Orca/OpenOrca", split="train")

def convert(example):
    msgs = [{"role": "system", "content": example['system_prompt']}]
    if example["question"]:
        msgs.append({"role": "user", "content": example["question"]})
    msgs.append({"role": "assistant", "content": example["response"]})
    return {"messages": msgs}

sam = ds.shuffle(seed=42).select(range(1000))
alp = sam.map(convert)
alp.to_json(dataset_path, orient="records", lines=True)
print(f"[PY] Finished writing dataset: {d}")
print(f"[PY] Created dataset file: {dataset_path}")

[PY] Ensured directory exists and writable: /opt/app-root/src/data
[PY] Ensured directory exists and writable: /opt/app-root/src/outputs
[PY] Ensured directory exists and writable: /opt/app-root/src/checkpoints


Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

[PY] Finished writing dataset: /opt/app-root/src/checkpoints
[PY] Created dataset file: /opt/app-root/src/data/dataset.jsonl


In [3]:
from kubeflow.trainer import TrainerClient, TrainingHubTrainer, TrainingHubAlgorithms
from kubeflow_trainer_api import models

client = TrainerClient()
print(client)

<kubeflow.trainer.api.trainer_client.TrainerClient object at 0x7f60cfd06dd0>


In [4]:
# NOTE: Needed to add cluster role binding to read the cluster scoped runtimes

for runtime in client.list_runtimes():
    if runtime.name == "training-hub-sft":
        sft_runtime = runtime
        print("Found runtime: " + str(sft_runtime))

Runtime deepspeed-distributed must have trainer.kubeflow.org/framework label.
Runtime mlx-distributed must have trainer.kubeflow.org/framework label.
Runtime mpi-distributed must have trainer.kubeflow.org/framework label.
Runtime torch-distributed must have trainer.kubeflow.org/framework label.
Runtime torchtune-llama3.2-3b must have trainer.kubeflow.org/framework label.
Runtime torchtune-llama3.2-3b-brian must have trainer.kubeflow.org/framework label.


Found runtime: Runtime(name='training-hub-sft', trainer=RuntimeTrainer(trainer_type=<TrainerType.TRAINING_HUB_TRAINER: 'TrainingHubTrainer'>, framework='training-hub', num_nodes=1, device='Unknown', device_count='Unknown'), pretrained_model=None)


In [5]:
args = {
    "model_path":"Qwen/Qwen2.5-0.5B",
    "data_path": "/opt/app-root/src/data/dataset.jsonl",
    "ckpt_output_dir": "/opt/app-root/src/checkpoints",
    "data_output_dir": "/opt/app-root/src/outputs",
    "num_epochs": 1,
    "effective_batch_size": 128,
    "max_tokens_per_gpu": 2048,
    "learning_rate": 1e-05,
    "max_seq_len": 512,
    "max_batch_len": 512,
    "save_samples": 0,
    "warmup_steps": 100,
    "checkpoint_at_epoch": True,
    "accelerate_full_state_at_epoch": True,
    "rdzv_id": 1,
    "disable_flash_attn": True,
    "packing": False,
    "enable_multipack": False,
    "fp16": True,
    "bf16": False,
    "gradient_checkpointing": True,
    "distributed_training_framework": "fsdp",
    "fsdp_sharding_strategy": "SHARD_GRAD_OP",
    "disable_multipack": True,
    "dtype": "float16",
    "nproc_per_node": 1,
    "nnodes": 2,
}

# Kubernetes volumes and mounts (PVC example)
volumes = [
    models.IoK8sApiCoreV1Volume(
        name="example",
        persistent_volume_claim=models.IoK8sApiCoreV1PersistentVolumeClaimVolumeSource(
            claim_name="example"
        ),
    ),
]
volume_mounts = [
    models.IoK8sApiCoreV1VolumeMount(
        name="example",
        mount_path="/opt/app-root/src/",
        read_only=False,
    ),
]

job_name = client.train(
    trainer=TrainingHubTrainer(
        algorithm=TrainingHubAlgorithms.SFT,
        func_args=args,
        packages_to_install=["training-hub"],
        volumes=volumes,
        volume_mounts=volume_mounts
    ),
    runtime=sft_runtime,
)