In [None]:
# Standard library imports
import logging
import os
import sys
import time
from io import StringIO

In [None]:
from kubernetes import client as k8s, config as k8s_config
# Edit to match your specific settings
api_server = os.getenv("OPENSHIFT_API_URL")
token = os.getenv("NOTEBOOK_USER_TOKEN")
if not api_server or not token:
    raise RuntimeError("OPENSHIFT_API_URL and NOTEBOOK_USER_TOKEN environment variables are required")
PVC_NAME = os.getenv("SHARED_PVC_NAME", "shared")

configuration = k8s.Configuration()
configuration.host = api_server
# TLS verification of the API server certificate. It stays disabled by default (the notebook's
# previous, unconditional behaviour) so clusters using a self-signed or untrusted CA keep
# working. Set K8S_VERIFY_SSL=true (or 1/yes) to verify the certificate before the bearer
# token below is sent to the server.
configuration.verify_ssl = os.getenv("K8S_VERIFY_SSL", "false").strip().lower() in ("1", "true", "yes")
configuration.api_key = {"authorization": f"Bearer {token}"}
api_client = k8s.ApiClient(configuration)

# Mount path of the shared PVC (same path in the notebook and in the training pods).
PVC_MOUNT_PATH = "/opt/app-root/src"

In [None]:
import os
import gzip
import shutil
import socket

import boto3
from botocore.config import Config as BotoConfig
from botocore.exceptions import ClientError

# S3/MinIO connection settings, all taken from the environment.
# An empty endpoint or bucket means the S3 download path is skipped.
s3_endpoint = os.getenv("AWS_DEFAULT_ENDPOINT", "")
s3_access_key = os.getenv("AWS_ACCESS_KEY_ID", "")
s3_secret_key = os.getenv("AWS_SECRET_ACCESS_KEY", "")
s3_bucket = os.getenv("AWS_STORAGE_BUCKET", "")
s3_prefix = os.getenv("AWS_STORAGE_BUCKET_LORA_SFT_DIR", "")

# Process-wide cap on every socket operation, so a stalled endpoint errors out
# instead of blocking the cell indefinitely.
socket.setdefaulttimeout(10)  # seconds

# Directory layout on the notebook's PVC (per the Notebook CR). The training pods mount
# the very same PVC at /opt/app-root/src, so these paths are valid in both places.
PVC_NOTEBOOK_PATH = "/opt/app-root/src"
DATASET_ROOT_NOTEBOOK = PVC_NOTEBOOK_PATH
TABLE_GPT_DIR = os.path.join(DATASET_ROOT_NOTEBOOK, "table-gpt-data", "train")
MODEL_DIR = os.path.join(DATASET_ROOT_NOTEBOOK, "Qwen", "Qwen2.5-1.5B-Instruct")
os.makedirs(TABLE_GPT_DIR, exist_ok=True)
os.makedirs(MODEL_DIR, exist_ok=True)

# Set to True by whichever source (S3 or HuggingFace) ends up providing the data.
data_download_successful = False

def stream_download(s3, bucket, key, dst):
    """Download an object from S3/MinIO to ``dst`` using 1 MiB streaming reads.

    The object is first written to ``dst + ".part"`` and renamed to ``dst`` only after the
    whole body has been read. An interrupted transfer therefore never leaves a truncated
    file at ``dst``, which matters because callers skip downloads whose ``dst`` already
    exists.

    Args:
        s3: boto3 S3 client (anything exposing a compatible ``get_object``).
        bucket: Bucket name.
        key: Object key to fetch.
        dst: Local destination path; its parent directory must already exist.

    Returns:
        True on success; False on any request, read, or write error (the error is printed).
    """
    print(f"[notebook] STREAM download s3://{bucket}/{key} -> {dst}")
    t0 = time.time()
    try:
        resp = s3.get_object(Bucket=bucket, Key=key)
    except ClientError as e:
        print(f"[notebook] CLIENT ERROR for {key}: {e.response.get('Error', {})}")
        return False
    except Exception as e:
        print(f"[notebook] OTHER ERROR for {key}: {e}")
        return False

    body = resp["Body"]
    part_path = dst + ".part"
    completed = False
    try:
        with open(part_path, "wb") as f:
            while True:
                chunk = body.read(1024 * 1024)
                if not chunk:
                    break
                f.write(chunk)
        # Publish the finished file in one step; os.replace overwrites any stale dst.
        os.replace(part_path, dst)
        completed = True
    except socket.timeout as e:
        print(f"[notebook] socket.timeout while reading {key}: {e}")
    except Exception as e:
        print(f"[notebook] ERROR writing to {dst}: {e}")
    finally:
        # Close the streaming body whether or not the transfer finished.
        try:
            body.close()
        except Exception as e:
            print(f"[notebook] Could not close response body for {key}: {e}")
        # Never leave a half-written temp file behind.
        if not completed and os.path.exists(part_path):
            try:
                os.remove(part_path)
            except OSError as e:
                print(f"[notebook] Could not remove partial file {part_path}: {e}")

    if not completed:
        return False

    t1 = time.time()
    print(f"[notebook] DONE stream {key} in {t1 - t0:.2f}s")
    return True

# Try S3 download first, fall back to HuggingFace if not configured or fails.
# S3 only counts as successful when the training dataset file is actually present afterwards:
# a bucket holding model files but no dataset must still trigger the HuggingFace fallback.
_expected_dataset_file = os.path.join(TABLE_GPT_DIR, "train_All_100.jsonl")


def _route_s3_object(rel):
    """Map an object path relative to the S3 prefix to ``(local_path, dir_label)``.

    Dataset files (path mentions ``table-gpt`` or ends in ``.jsonl``) are flattened into
    TABLE_GPT_DIR. Model files (path mentions ``qwen`` or has a typical model-file
    extension) go under MODEL_DIR, keeping any sub-path below ``Qwen2.5-1.5B-Instruct/``.
    Everything else mirrors its relative path under DATASET_ROOT_NOTEBOOK.
    """
    lowered = rel.lower()
    if "table-gpt" in lowered or rel.endswith(".jsonl"):
        return os.path.join(TABLE_GPT_DIR, os.path.basename(rel)), "dataset"
    if "qwen" in lowered or rel.endswith((".bin", ".json", ".model", ".safetensors", ".txt")):
        if "Qwen2.5-1.5B-Instruct" in rel:
            sub_path = rel.split("Qwen2.5-1.5B-Instruct/")[-1]
        else:
            sub_path = os.path.basename(rel)
        return os.path.join(MODEL_DIR, sub_path), "model"
    return os.path.join(DATASET_ROOT_NOTEBOOK, rel), "default"


def _is_under_root(path, root):
    """True when ``path`` stays inside ``root`` after normalisation (rejects ``..`` tricks)."""
    root_abs = os.path.abspath(root)
    return os.path.commonpath([root_abs, os.path.abspath(path)]) == root_abs


def _gunzip_in_place(gz_path):
    """Decompress ``gz_path`` next to itself and delete the archive on success.

    Does nothing if the decompressed file already exists. On failure, the partially
    written output is removed so a later run does not mistake it for a finished file.
    """
    out_path = os.path.splitext(gz_path)[0]
    if os.path.exists(out_path):
        return
    print(f"[notebook] Decompressing {gz_path} -> {out_path}")
    try:
        with gzip.open(gz_path, "rb") as f_in, open(out_path, "wb") as f_out:
            shutil.copyfileobj(f_in, f_out)
    except Exception as e:
        print(f"[notebook] Failed to decompress {gz_path}: {e}")
        if os.path.exists(out_path):
            try:
                os.remove(out_path)
            except OSError as rm_error:
                print(f"[notebook] Could not remove partial output {out_path}: {rm_error}")
        return
    try:
        os.remove(gz_path)
    except OSError as e:
        print(f"[notebook] Could not remove {gz_path} after decompression: {e}")


if s3_endpoint and s3_bucket:
    try:
        endpoint_url = s3_endpoint if s3_endpoint.startswith("http") else f"https://{s3_endpoint}"
        prefix = (s3_prefix or "").strip("/")
        # List "<prefix>/" rather than "<prefix>", so sibling prefixes such as
        # "<prefix>-old/..." are not pulled in as well.
        list_prefix = f"{prefix}/" if prefix else ""

        print(f"[notebook] S3 configured: endpoint={endpoint_url}, bucket={s3_bucket}, prefix={prefix or '<root>'}")

        boto_cfg = BotoConfig(
            signature_version="s3v4",
            s3={"addressing_style": "path"},
            retries={"max_attempts": 1, "mode": "standard"},
            connect_timeout=5,
            read_timeout=10,
        )

        # Create S3/MinIO client. NOTE: verify=False disables TLS certificate checks
        # (tolerates self-signed MinIO endpoints).
        s3 = boto3.client(
            "s3",
            endpoint_url=endpoint_url,
            aws_access_key_id=s3_access_key,
            aws_secret_access_key=s3_secret_key,
            config=boto_cfg,
            verify=False,
        )

        paginator = s3.get_paginator("list_objects_v2")
        pulled_any = False
        file_count = 0

        print(f"[notebook] Starting S3 download from prefix: {prefix}")
        for page in paginator.paginate(Bucket=s3_bucket, Prefix=list_prefix):
            contents = page.get("Contents", [])
            if not contents:
                print("[notebook] No contents found in this page")
                continue

            print(f"[notebook] Found {len(contents)} objects in this page")

            for obj in contents:
                key = obj["Key"]

                # Skip "directory markers"
                if key.endswith("/"):
                    print(f"[notebook] Skipping directory marker: {key}")
                    continue

                # Relative path under the prefix; leading "/" is stripped so it can never
                # turn into an absolute local path.
                rel = key[len(list_prefix):].lstrip("/")
                if not rel:
                    print(f"[notebook] Skipping key with empty relative path: {key}")
                    continue
                file_count += 1
                print(f"[notebook] Processing key={key}, rel={rel}")

                dst, dir_label = _route_s3_object(rel)
                if not _is_under_root(dst, DATASET_ROOT_NOTEBOOK):
                    print(f"[notebook] Refusing key {key}: resolves outside {DATASET_ROOT_NOTEBOOK}")
                    continue
                print(f"[notebook] Routing to {dir_label} dir: {dst}")

                os.makedirs(os.path.dirname(dst), exist_ok=True)

                # A .gz archive is deleted after decompression, so its decompressed
                # sibling also marks the object as already fetched.
                already_present = os.path.exists(dst) or (
                    dst.endswith(".gz") and os.path.exists(os.path.splitext(dst)[0])
                )
                if already_present:
                    print(f"[notebook] Skipping existing file {dst}")
                elif not stream_download(s3, s3_bucket, key, dst):
                    print(f"[notebook] Download failed for {key}")
                    continue
                pulled_any = True

                # If the file is .gz, decompress and remove the .gz
                if dst.endswith(".gz") and os.path.exists(dst):
                    _gunzip_in_place(dst)

        if not pulled_any:
            print("[notebook] ✗ S3 download found no files to download")
        elif not os.path.exists(_expected_dataset_file):
            print(f"[notebook] ✗ S3 download finished but {_expected_dataset_file} is missing")
        else:
            print(f"[notebook] ✓ S3 download successful. Processed {file_count} files")
            data_download_successful = True

    except Exception as e:
        print(f"[notebook] ✗ S3 fetch failed: {e}")
        import traceback
        traceback.print_exc()
        print("[notebook] Will attempt HuggingFace fallback...")
else:
    print("[notebook] S3 not configured (missing endpoint or bucket env vars)")

# HuggingFace fallback: runs only when S3 was unconfigured or did not succeed.
# Needs outbound internet access, so it cannot help in disconnected clusters.
if not data_download_successful:
    print("[notebook] Attempting HuggingFace dataset download (requires internet)...")
    try:
        import json
        import random
        from datasets import load_dataset

        print("[notebook] Loading Table-GPT dataset from HuggingFace...")
        dataset = load_dataset("LipengCS/Table-GPT", "All")
        train_data = dataset["train"]
        print(f"[notebook] Original training set size: {len(train_data)}")

        # Reproducible subset: seeded sample of up to 100 row indices.
        random.seed(42)
        subset_indices = random.sample(range(len(train_data)), min(100, len(train_data)))
        subset_data = train_data.select(subset_indices)
        print(f"[notebook] Subset size: {len(subset_data)}")

        # JSONL output: one JSON object per line, at the path the training params point to.
        output_file = os.path.join(TABLE_GPT_DIR, "train_All_100.jsonl")
        with open(output_file, "w") as f:
            for example in subset_data:
                f.write(json.dumps(example) + "\n")

        print(f"[notebook] ✓ HuggingFace download successful. Subset saved to {output_file}")
        data_download_successful = True

    except Exception as hf_error:
        import traceback
        print(f"[notebook] ✗ HuggingFace download failed: {hf_error}")
        traceback.print_exc()
        raise RuntimeError(
            "Failed to download dataset from both S3 and HuggingFace. "
            "In disconnected environments, ensure S3/MinIO is configured with the required data. "
            "In connected environments, check your internet connection and credentials."
        ) from hf_error

# Pre-flight check: the dataset file is mandatory, so stop here if it is missing.
dataset_file = os.path.join(TABLE_GPT_DIR, "train_All_100.jsonl")
if not os.path.exists(dataset_file):
    raise RuntimeError(f"Dataset file not found: {dataset_file}")
print(f"[notebook] ✓ Dataset ready: {dataset_file}")

# The model directory may legitimately be empty; only report what will happen.
if not (os.path.exists(MODEL_DIR) and os.listdir(MODEL_DIR)):
    print(f"[notebook] Note: Model directory is empty: {MODEL_DIR}")
    print("[notebook] Training will download model from HuggingFace during execution")
else:
    print(f"[notebook] ✓ Model files ready in: {MODEL_DIR}")
    print(f"[notebook] Model files: {os.listdir(MODEL_DIR)[:5]}...")  # Show first 5 files

In [None]:
# Determine model path based on whether S3 download succeeded
import os
LOCAL_MODEL_PATH = "/opt/app-root/src/Qwen/Qwen2.5-1.5B-Instruct"
HUGGINGFACE_MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"

# Check if model was downloaded from S3
model_downloaded = os.path.exists(LOCAL_MODEL_PATH) and len(os.listdir(LOCAL_MODEL_PATH)) > 0

if model_downloaded:
    model_path_to_use = LOCAL_MODEL_PATH
    print(f"✓ Using local model from S3: {model_path_to_use}")
else:
    model_path_to_use = HUGGINGFACE_MODEL_ID  
    print(f"✓ Using HuggingFace model ID: {model_path_to_use}")

params = {
    ###########################################################################
    # 🤖 Model + Data Paths                                                   #
    ###########################################################################
    "model_path": model_path_to_use,
    "data_path": "/opt/app-root/src/table-gpt-data/train/train_All_100.jsonl",
    "ckpt_output_dir": "/opt/app-root/src/checkpoints-logs-dir",
    "data_output_path": "/opt/app-root/src/lora-sft-json/_data",
    ############################################################################
    # 🏋️‍♀️ Training Hyperparameters                                              #
    ############################################################################
    "effective_batch_size": 128,
    "learning_rate": 1.0e-4,  # LoRA typically uses higher learning rate than full fine-tuning
    "num_epochs": 1,
    "lr_scheduler": "cosine",
    "warmup_steps": 0,
    "seed": 42,
    ############################################################################
    # 🔧 LoRA Configuration                                                    #
    ############################################################################
    "lora_r": 32,              # LoRA rank (from lora_example.py default)
    "lora_alpha": 64,          # LoRA alpha parameter (from lora_example.py default)
    "lora_dropout": 0.0,       # LoRA dropout (optimized for Unsloth)
    ###########################################################################
    # 🏎️ Performance Hyperparameters                                          #
    ###########################################################################
    "max_tokens_per_gpu": 32000,
    "max_seq_len": 2048,
    ############################################################################
    # 💾 Checkpointing Settings                                                #
    ############################################################################
    "save_final_checkpoint": True,
    "checkpoint_at_epoch": False,
}


In [None]:
from kubeflow.trainer import TrainerClient
from kubeflow.trainer.rhai import TrainingHubAlgorithms
from kubeflow.trainer.rhai import TrainingHubTrainer
from kubeflow_trainer_api import models
from kubeflow.common.types import KubernetesBackendConfig

# Hand the Trainer SDK the Kubernetes client configuration built earlier (API host,
# bearer token, TLS setting), so it talks to the same cluster as this notebook.
backend_cfg = KubernetesBackendConfig(client_configuration=api_client.configuration)
client = TrainerClient(backend_cfg)
print(client)

In [None]:
# Training runtime to launch the job on. It is defined once here; override it with the
# TRAINING_RUNTIME_NAME environment variable to target a different runtime image.
TRAINING_RUNTIME_NAME = os.getenv("TRAINING_RUNTIME_NAME", "training-hub03-cuda128-torch28-py312")

# Materialise the listing once: it is searched and, on failure, reported back to the user.
available_runtimes = list(client.list_runtimes())
th_runtime = next((rt for rt in available_runtimes if rt.name == TRAINING_RUNTIME_NAME), None)

if th_runtime is None:
    known_names = ", ".join(sorted(rt.name for rt in available_runtimes)) or "<none>"
    raise RuntimeError(
        f"Required runtime '{TRAINING_RUNTIME_NAME}' not found (available: {known_names})"
    )
print("Found runtime: " + str(th_runtime))

In [None]:
from kubeflow.trainer.options.kubernetes import (
    PodTemplateOverrides,
    PodTemplateOverride,
    PodSpecOverride,
    ContainerOverride,
)

# Every persistent path (caches, and the volume mount itself) is derived from the single
# PVC_MOUNT_PATH constant. The shared PVC is mounted in the training pods at the same
# path the notebook uses, so one constant keeps the mount and the caches in sync.
cache_root = os.path.join(PVC_MOUNT_PATH, ".cache", "huggingface")
triton_cache = os.path.join(PVC_MOUNT_PATH, ".triton")
xdg_cache = os.path.join(PVC_MOUNT_PATH, ".cache")

job_name = client.train(
    trainer=TrainingHubTrainer(
        algorithm=TrainingHubAlgorithms.LORA_SFT,
        func_args=params,
        env={
            "HF_HOME": cache_root,
            "TRITON_CACHE_DIR": triton_cache,
            "XDG_CACHE_HOME": xdg_cache,
            "NCCL_DEBUG": "INFO",
        },
    ),
    options=[
        # Attach the shared PVC to the "node" container of the training job, read-write.
        PodTemplateOverrides(
            PodTemplateOverride(
                target_jobs=["node"],
                spec=PodSpecOverride(
                    volumes=[
                        {"name": "work", "persistentVolumeClaim": {"claimName": PVC_NAME}},
                    ],
                    containers=[
                        ContainerOverride(
                            name="node",
                            volume_mounts=[
                                {"name": "work", "mountPath": PVC_MOUNT_PATH, "readOnly": False},
                            ],
                        )
                    ],
                ),
            )
        )
    ],
    runtime=th_runtime,
)

In [None]:
# Wait for the job to start, then for it to reach a terminal state.
# The first wait also accepts "Complete"/"Failed": a job that terminates before it is ever
# observed as "Running" could otherwise make this wait raise or time out before the checks
# below get a chance to print its logs.
client.wait_for_job_status(name=job_name, status={"Running", "Complete", "Failed"}, timeout=300)
client.wait_for_job_status(name=job_name, status={"Complete", "Failed"}, timeout=1800)  # 30 minutes for training

# Get job details and logs.
# NOTE(review): this unpacks get_job_logs(...) into (pod_logs, _) and reads .values(), while a
# later cell iterates the same call line by line. Confirm the return shape of the installed SDK.
job = client.get_job(name=job_name)
pod_logs, _ = client.get_job_logs(name=job_name, follow=False)

# Flatten all pod logs into a single list of lines
logs = []
for log_text_content in pod_logs.values():
    logs.extend(str(log_text_content).splitlines())

log_text = "\n".join(logs)

print(f"Training job final status: {job.status}")


def _print_log_tail(lines, count):
    """Print the last ``count`` lines of ``lines`` under a short header."""
    print(f"Last {count} lines of logs:")
    for line in lines[-count:]:
        print(line)


# Check 1: Job status must not be "Failed"
if job.status == "Failed":
    print(f"ERROR: Training job '{job_name}' has Failed status")
    _print_log_tail(logs, 30)
    raise RuntimeError(f"Training job '{job_name}' failed")

# Check 2: Look for the training completion message in logs.
# This is critical because the training script may catch exceptions and exit 0.
if "[PY] LORA_SFT training complete. Result=" not in log_text:
    print("ERROR: Training completion message not found in logs")
    _print_log_tail(logs, 50)
    raise RuntimeError("Training did not complete successfully - missing completion message")

print(f"✓ Training job '{job_name}' completed successfully")

In [None]:
# One summary line per job step: its name, status and device allocation.
finished_job = client.get_job(name=job_name)
for step in finished_job.steps:
    print(f"Step: {step.name}, Status: {step.status}, Devices: {step.device} x {step.device_count}\n")

In [None]:
# Dump the job logs (non-streaming).
# NOTE(review): the log-check cell unpacks this same call as a (pod_logs, _) pair. If that is
# the real return shape, this loop prints those two objects rather than individual lines.
log_stream = client.get_job_logs(job_name, follow=False)
for log_line in log_stream:
    print(log_line)

In [None]:
# Clean up: delete the TrainJob submitted earlier in this notebook.
client.delete_job(job_name)