In [None]:
%pip install datasets --quiet
# Install the Kubeflow SDK from its main branch (for testing unreleased changes)
%pip install git+https://github.com/opendatahub-io/kubeflow-sdk.git@main --quiet

In [None]:
# Standard library imports
import logging
import os
import sys
import time
from io import StringIO

In [None]:
from dotenv import load_dotenv
# Load variables from a local .env file into the process environment; the cells
# below read all of their configuration through os.getenv().
load_dotenv()

In [None]:
from kubernetes import client as k8s, config as k8s_config

# Cluster connection settings are taken from the environment (e.g. the .env file).
api_server = os.getenv("OPENSHIFT_API_URL")
token = os.getenv("NOTEBOOK_USER_TOKEN")
if not api_server or not token:
    raise RuntimeError("OPENSHIFT_API_URL and NOTEBOOK_USER_TOKEN environment variables are required")

# Name of the shared PVC that both this notebook and the training pods use.
PVC_NAME = os.getenv("SHARED_PVC_NAME", "")
if not PVC_NAME:
    raise RuntimeError("SHARED_PVC_NAME environment variable is required")

configuration = k8s.Configuration()
configuration.host = api_server
# TLS verification of the API server is OFF by default so clusters with a
# self-signed or untrusted CA work out of the box. Set K8S_VERIFY_SSL=true
# (or 1 / yes) to turn certificate verification back on.
configuration.verify_ssl = os.getenv("K8S_VERIFY_SSL", "false").strip().lower() in ("1", "true", "yes")
configuration.api_key = {"authorization": f"Bearer {token}"}
api_client = k8s.ApiClient(configuration)

# Mount path of the shared PVC; the training pods mount it at this same path.
PVC_MOUNT_PATH = "/opt/app-root/src"

In [None]:
# Convert raw sql-create-context examples into chat-message format.
def convert_to_messages(example):
    """
    Turn one sql-create-context record into a two-turn chat conversation.

    The user turn carries the table schema (``context``) followed by the
    natural-language ``question``; the assistant turn is the reference SQL
    (``answer``).

    Args:
        example: Mapping with ``context``, ``question`` and ``answer`` keys.

    Returns:
        dict: ``{"messages": [user_turn, assistant_turn]}`` in chat-template form.
    """
    schema = example['context']
    question = example['question']
    prompt = (
        "Given the following database schema:\n"
        "\n"
        f"{schema}\n"
        "\n"
        f"Write a SQL query to answer this question: {question}"
    )

    conversation = [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": example['answer']},
    ]
    return {"messages": conversation}

In [None]:
import os
import gzip
import shutil
import socket

import boto3
from botocore.config import Config as BotoConfig
from botocore.exceptions import ClientError

# Process-wide default timeout (seconds) for every blocking socket operation, so a
# stalled S3/MinIO endpoint cannot hang this cell indefinitely.
# NOTE(review): being process-wide, this also applies to every later network call
# made from this kernel.
socket.setdefaulttimeout(10)

# Root of the notebook's PVC mount (per the Notebook CR). Training pods mount the
# same PVC at /opt/app-root/src, so paths built from it are valid in both places.
PVC_NOTEBOOK_PATH = "/opt/app-root/src/"
DATASET_ROOT_NOTEBOOK = PVC_NOTEBOOK_PATH
TXT_SQL_DIR = os.path.join(DATASET_ROOT_NOTEBOOK, "txt-sql-data", "train")
MODEL_DIR = os.path.join(DATASET_ROOT_NOTEBOOK, "Qwen", "Qwen2.5-1.5B-Instruct")
for _local_dir in (TXT_SQL_DIR, MODEL_DIR):
    os.makedirs(_local_dir, exist_ok=True)
del _local_dir

# S3/MinIO settings; each one is "" when its variable is unset.
s3_endpoint, s3_access_key, s3_secret_key, s3_bucket, s3_prefix = (
    os.getenv(_var, "")
    for _var in (
        "AWS_DEFAULT_ENDPOINT",
        "AWS_ACCESS_KEY_ID",
        "AWS_SECRET_ACCESS_KEY",
        "AWS_STORAGE_BUCKET",
        "AWS_STORAGE_BUCKET_LORA_DIR",
    )
)

# Set to True further down once S3 or the HuggingFace fallback has produced data.
data_download_successful = False

def stream_download(s3, bucket, key, dst):
    """
    Download an object from S3/MinIO using get_object and streaming reads.

    The body is streamed in 1 MiB chunks into a temporary ``<dst>.part`` file,
    which is renamed onto ``dst`` only after the whole object has been read.
    On any failure the partial file is deleted, so a later "download only if
    missing" check can never mistake a truncated file for a complete one. The
    response body is always closed to release its connection.

    Args:
        s3: S3 client exposing ``get_object(Bucket=..., Key=...)``.
        bucket: Bucket name.
        key: Object key.
        dst: Local destination path.

    Returns:
        True on success, False on any error.
    """
    print(f"[notebook] STREAM download s3://{bucket}/{key} -> {dst}")
    t0 = time.time()

    try:
        resp = s3.get_object(Bucket=bucket, Key=key)
    except ClientError as e:
        err = e.response.get("Error", {})
        print(f"[notebook] CLIENT ERROR (get_object) for {key}: {err}")
        return False
    except Exception as e:
        print(f"[notebook] OTHER ERROR (get_object) for {key}: {e}")
        return False

    body = resp["Body"]
    tmp_dst = dst + ".part"
    completed = False
    try:
        with open(tmp_dst, "wb") as f:
            while True:
                try:
                    chunk = body.read(1024 * 1024)  # 1MB per chunk
                except socket.timeout as e:
                    print(f"[notebook] socket.timeout while reading {key}: {e}")
                    return False
                if not chunk:
                    break
                f.write(chunk)
        # Publish the file only once it is complete.
        os.replace(tmp_dst, dst)
        completed = True
    except Exception as e:
        print(f"[notebook] ERROR writing to {dst} for {key}: {e}")
        return False
    finally:
        close = getattr(body, "close", None)
        if close is not None:
            try:
                close()
            except Exception:
                pass  # best effort: the download outcome is already decided
        if not completed and os.path.exists(tmp_dst):
            try:
                os.remove(tmp_dst)
            except OSError as e:
                print(f"[notebook] Could not remove partial file {tmp_dst}: {e}")

    t1 = time.time()
    print(f"[notebook] DONE  stream {key} in {t1 - t0:.2f}s")
    return True


if s3_endpoint and s3_bucket:
    try:
        # Normalize endpoint URL: assume https when no scheme is given
        endpoint_url = (
            s3_endpoint
            if s3_endpoint.startswith("http")
            else f"https://{s3_endpoint}"
        )
        prefix = (s3_prefix or "").strip("/")

        print(
            f"[notebook] S3 configured: "
            f"endpoint={endpoint_url}, bucket={s3_bucket}, prefix={prefix or '<root>'}"
        )

        # Boto config: single attempt, reasonable connect/read timeouts
        boto_cfg = BotoConfig(
            signature_version="s3v4",
            s3={"addressing_style": "path"},
            retries={"max_attempts": 1, "mode": "standard"},
            connect_timeout=5,
            read_timeout=10,
        )

        # Create S3/MinIO client.
        # NOTE(review): verify=False disables TLS certificate checks for this endpoint.
        s3 = boto3.client(
            "s3",
            endpoint_url=endpoint_url,
            aws_access_key_id=s3_access_key,
            aws_secret_access_key=s3_secret_key,
            config=boto_cfg,
            verify=False,
        )

        # List and download all objects under the prefix
        paginator = s3.get_paginator("list_objects_v2")
        pulled_any = False
        file_count = 0  # real objects only; directory markers are not counted

        print(f"[notebook] Starting S3 download from prefix: {prefix}")
        for page in paginator.paginate(Bucket=s3_bucket, Prefix=prefix or ""):
            contents = page.get("Contents", [])
            if not contents:
                print("[notebook] No contents found in this page")
                continue

            print(f"[notebook] Found {len(contents)} objects in this page")

            for obj in contents:
                key = obj["Key"]

                # Skip "directory markers"
                if key.endswith("/"):
                    print(f"[notebook] Skipping directory marker: {key}")
                    continue
                file_count += 1

                # Determine relative path under prefix for local storage
                rel = key[len(prefix):].lstrip("/") if prefix else key
                print(f"[notebook] Processing key={key}, rel={rel}")

                # Route to the appropriate directory based on content type:
                # dataset files are flattened into TXT_SQL_DIR; model files keep
                # their layout below ".../Qwen2.5-1.5B-Instruct/" inside MODEL_DIR.
                if "table-gpt" in rel.lower() or rel.endswith(".jsonl"):
                    dst = os.path.join(TXT_SQL_DIR, os.path.basename(rel))
                    print(f"[notebook] Routing to dataset dir: {dst}")
                elif "qwen" in rel.lower() or any(rel.endswith(ext) for ext in [".bin", ".json", ".model", ".safetensors", ".txt"]):
                    model_rel = (
                        rel.split("Qwen2.5-1.5B-Instruct/")[-1]
                        if "Qwen2.5-1.5B-Instruct" in rel
                        else os.path.basename(rel)
                    )
                    dst = os.path.join(MODEL_DIR, model_rel)
                    print(f"[notebook] Routing to model dir: {dst}")
                else:
                    # Default: use the relative path as-is
                    dst = os.path.join(DATASET_ROOT_NOTEBOOK, rel)
                    print(f"[notebook] Routing to default dir: {dst}")

                os.makedirs(os.path.dirname(dst), exist_ok=True)

                # A ".gz" object is deleted after a successful decompression, so its
                # decompressed sibling also counts as "already present"; otherwise the
                # archive would be downloaded again on every run.
                is_gz = dst.endswith(".gz")
                out_path = os.path.splitext(dst)[0] if is_gz else None

                # Download only if missing
                if os.path.exists(dst):
                    print(f"[notebook] Skipping existing file {dst}")
                elif is_gz and os.path.exists(out_path):
                    print(f"[notebook] Skipping existing file {out_path}")
                else:
                    ok = stream_download(s3, s3_bucket, key, dst)
                    if not ok:
                        print(f"[notebook] Download failed for {key}")
                        continue
                pulled_any = True

                # Decompress .gz files through a temporary file so a failure never
                # leaves a truncated output that later runs would accept as done,
                # then remove the archive.
                if is_gz and os.path.exists(dst) and not os.path.exists(out_path):
                    print(f"[notebook] Decompressing {dst} -> {out_path}")
                    tmp_out = out_path + ".part"
                    try:
                        with gzip.open(dst, "rb") as f_in, open(tmp_out, "wb") as f_out:
                            shutil.copyfileobj(f_in, f_out)
                        os.replace(tmp_out, out_path)
                    except Exception as e:
                        print(f"[notebook] Failed to decompress {dst}: {e}")
                        try:
                            os.remove(tmp_out)
                        except OSError:
                            pass  # temp file may never have been created
                    else:
                        try:
                            os.remove(dst)
                        except OSError as e:
                            print(f"[notebook] Could not remove {dst}: {e}")

        if pulled_any:
            print(f"[notebook] ✓ S3 download successful. Processed {file_count} files")
            data_download_successful = True
        else:
            print("[notebook] ✗ S3 download found no files to download")

    except Exception as e:
        print(f"[notebook] ✗ S3 fetch failed: {e}")
        import traceback
        traceback.print_exc()
        print("[notebook] Will attempt HuggingFace fallback...")
else:
    print("[notebook] S3 not configured (missing endpoint or bucket env vars)")

# Fall back to HuggingFace when S3 was not configured, failed, or did not provide
# the training dataset file (requires internet). S3 may legitimately return only
# model files, in which case the dataset still has to come from somewhere.
_expected_dataset_file = os.path.join(TXT_SQL_DIR, "train_All_100.jsonl")
if not data_download_successful or not os.path.exists(_expected_dataset_file):
    if data_download_successful:
        print(f"[notebook] S3 data did not include {_expected_dataset_file}")
    print("[notebook] Attempting HuggingFace dataset download (requires internet)...")
    try:
        import json
        from datasets import load_dataset

        # Load the sql-create-context dataset
        print("[notebook] Loading sql-create-context dataset from HuggingFace...")
        dataset = load_dataset("b-mc2/sql-create-context", split="train")

        TRAIN_SIZE = 100  # Adjust based on your time/compute budget

        # Shuffle deterministically and select a subset
        train_dataset = dataset.shuffle(seed=42).select(range(min(TRAIN_SIZE, len(dataset))))

        # Convert to messages format
        train_data = [convert_to_messages(example) for example in train_dataset]

        # Save the subset to a JSONL file (one JSON object per line)
        output_file = _expected_dataset_file
        with open(output_file, "w", encoding="utf-8") as f:
            for example in train_data:
                f.write(json.dumps(example) + "\n")

        print(f"[notebook] ✓ HuggingFace download successful. Subset saved to {output_file}")
        data_download_successful = True

    except Exception as hf_error:
        print(f"[notebook] ✗ HuggingFace download failed: {hf_error}")
        import traceback
        traceback.print_exc()
        raise RuntimeError(
            "Failed to download dataset from both S3 and HuggingFace. "
            "In disconnected environments, ensure S3/MinIO is configured with the required data. "
            "In connected environments, check your internet connection and credentials."
        ) from hf_error

# Confirm the training dataset is where the training pods will look for it.
dataset_file = os.path.join(TXT_SQL_DIR, "train_All_100.jsonl")
if not os.path.exists(dataset_file):
    raise RuntimeError(f"Dataset file not found: {dataset_file}")
print(f"[notebook] ✓ Dataset ready: {dataset_file}")

# An empty model directory is not fatal at this point; just report what is there.
model_files = os.listdir(MODEL_DIR) if os.path.exists(MODEL_DIR) else []
if model_files:
    print(f"[notebook] ✓ Model files ready in: {MODEL_DIR}")
    print(f"[notebook] Model files: {model_files[:5]}...")  # Show first 5 files
else:
    print(f"[notebook] Note: Model directory is empty: {MODEL_DIR}")
    print("[notebook] Training will download model from HuggingFace during execution")

In [None]:
# Model download - reuse the copy already staged in MODEL_DIR on the shared PVC
# (from S3 or an earlier run); otherwise fetch it from HuggingFace into MODEL_DIR.
if os.path.exists(MODEL_DIR) and os.listdir(MODEL_DIR):
    model_path = MODEL_DIR
    print(f"✓ Using local model: {model_path}")
else:
    print("[notebook] Model not found locally, downloading from HuggingFace...")
    from huggingface_hub import snapshot_download

    # Deliberately not named `token`: that variable holds the OpenShift bearer
    # token from the cluster-configuration cell and must not be clobbered.
    hf_token = os.getenv("HUGGINGFACE_HUB_TOKEN")
    # NOTE(review): recent huggingface_hub releases deprecate `resume_download` and
    # `local_dir_use_symlinks`; confirm against the installed version before removing.
    model_path = snapshot_download(
        repo_id="Qwen/Qwen2.5-1.5B-Instruct",
        local_dir=MODEL_DIR,
        token=hf_token,
        resume_download=True,
        local_dir_use_symlinks=False,
    )
    print(f"✓ Model downloaded to: {model_path}")

In [None]:
# Training configuration
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"

# You can also try these alternatives:
# MODEL_NAME = "Qwen/Qwen2.5-3B-Instruct"    # Larger, more capable
# MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"  # Smaller, faster training
# MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"  # Alternative architecture

# The model-download cell stages Qwen/Qwen2.5-1.5B-Instruct on the shared PVC as
# `model_path`, and the training pods mount that PVC at the same path. Train from
# the local copy so the job does not need HuggingFace access. Any other MODEL_NAME
# is passed through as a Hub ID (presumably fetched by the training pods).
STAGED_MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
TRAIN_MODEL_PATH = model_path if MODEL_NAME == STAGED_MODEL_NAME else MODEL_NAME

# LoRA configuration
LORA_R = 16                 # Rank - start small, increase if needed
LORA_ALPHA = 32             # Alpha - typically 2x rank
LORA_DROPOUT = 0.0          # Dropout - 0.0 is optimized for Unsloth

# Training hyper-parameters
NUM_EPOCHS = 2              # More epochs = better learning, longer training
LEARNING_RATE = 2e-4        # Standard LoRA learning rate
MAX_SEQ_LEN = 1024          # Maximum sequence length
MICRO_BATCH_SIZE = 16       # Batch size per GPU (reduce if OOM)
GRADIENT_ACCUMULATION = 4   # Effective batch = micro_batch * grad_accum

# QLoRA settings (True enables 4-bit quantization; useful with limited GPU memory)
USE_QLORA = True

params = {
        # Model and data paths (data and outputs live on the shared PVC)
        'model_path': TRAIN_MODEL_PATH,
        'data_path': os.path.join(TXT_SQL_DIR, "train_All_100.jsonl"),
        'ckpt_output_dir': os.path.join(PVC_MOUNT_PATH, "checkpoints-logs-dir"),
        'data_output_path': os.path.join(PVC_MOUNT_PATH, "lora-json", "_data"),
        # Important for LORA
        'lr_scheduler': "cosine",
        'warmup_steps': 0,
        'seed': 42,
        # LoRA configuration
        'lora_r': LORA_R,
        'lora_alpha': LORA_ALPHA,
        'lora_dropout': LORA_DROPOUT,

        # Training configuration
        'num_epochs': NUM_EPOCHS,
        'learning_rate': LEARNING_RATE,
        'micro_batch_size': MICRO_BATCH_SIZE,
        'max_seq_len': MAX_SEQ_LEN,
        'gradient_accumulation_steps': GRADIENT_ACCUMULATION,

        # Dataset format
        'dataset_type' : "chat_template",
        'field_messages' : "messages",
        # Quantization
        'load_in_4bit': USE_QLORA,
        # GPU configuration
        # NOTE(review): 2 processes per node, but the train() cell requests
        # "nvidia.com/gpu": 1 per node - confirm which one is intended.
        'nproc_per_node' : 2,
        'nnodes' : 2,
        # Logging
        'logging_steps': 10,
        'save_steps': 200,
        'save_total_limit': 3,

        # Model Checkpointing
        'save_final_checkpoint': True,
        'checkpoint_at_epoch': 2,
}

In [None]:
from kubeflow.trainer import TrainerClient
from kubeflow.trainer.rhai import TrainingHubAlgorithms
from kubeflow.trainer.rhai import TrainingHubTrainer
from kubeflow.common.types import KubernetesBackendConfig

# Build the Kubeflow Trainer client on top of the authenticated Kubernetes client
# configuration (host, bearer token, TLS setting) created in the cluster-config cell.
backend_cfg = KubernetesBackendConfig(client_configuration=api_client.configuration)
client = TrainerClient(backend_cfg)

In [None]:
# The training runtime to launch the job with is selected by name via the environment.
training_runtime_name = os.getenv("TRAINING_RUNTIME")
if not training_runtime_name:
    raise RuntimeError("TRAINING_RUNTIME environment variable is required")

# First runtime whose name matches, or None when the cluster does not offer it.
th_runtime = next(
    (rt for rt in client.list_runtimes() if rt.name == training_runtime_name),
    None,
)
if th_runtime is None:
    raise RuntimeError(f"Required runtime '{training_runtime_name}' not found")
print("Found runtime: " + str(th_runtime))

In [None]:
from kubeflow.trainer.options.kubernetes import (
    PodTemplateOverrides,
    PodTemplateOverride,
    PodSpecOverride,
    ContainerOverride,
)

# Keep HuggingFace, Triton and XDG caches on the shared PVC so they persist across pods.
cache_root = os.path.join(PVC_MOUNT_PATH, ".cache", "huggingface")
triton_cache = os.path.join(PVC_MOUNT_PATH, ".triton")

# Submit the LoRA SFT job; `job_name` identifies it in the cells below.
job_name = client.train(
    trainer=TrainingHubTrainer(
        algorithm=TrainingHubAlgorithms.LORA_SFT,
        func_args=params,
        env={
            "HF_HOME": cache_root,
            "TRITON_CACHE_DIR": triton_cache,
            "XDG_CACHE_HOME": os.path.join(PVC_MOUNT_PATH, ".cache"),
            "NCCL_DEBUG": "INFO",
        },
        resources_per_node={
            "cpu": 4,
            "memory": "32Gi",
            # NOTE(review): params["nproc_per_node"] is 2 - confirm that a single
            # GPU per node is what this job should request.
            "nvidia.com/gpu": 1
        },
    ),
    options=[
        # Mount the shared PVC into the "node" containers at the same path the
        # notebook uses, so data/model/checkpoint paths in `params` resolve there too.
        PodTemplateOverrides(
            PodTemplateOverride(
                target_jobs=["node"],
                spec=PodSpecOverride(
                    volumes=[
                        {"name": "work", "persistentVolumeClaim": {"claimName": PVC_NAME}},
                    ],
                    containers=[
                        ContainerOverride(
                            name="node",
                            volume_mounts=[
                                {"name": "work", "mountPath": PVC_MOUNT_PATH, "readOnly": False},
                            ],
                        )
                    ],
                ),
            )
        )
    ],
    runtime=th_runtime,
)

In [None]:
# Block until the job is Running (up to 5 min), then until it reaches a terminal
# state, Complete or Failed (up to 30 min for LoRA training).
client.wait_for_job_status(name=job_name, status={"Running"}, timeout=300)
client.wait_for_job_status(name=job_name, status={"Complete", "Failed"}, timeout=1800)

# Fetch the final job object and its logs; the log generator is materialized
# once into a list and reused by both checks below.
job = client.get_job(name=job_name)
pod_logs = client.get_job_logs(name=job_name, follow=False)
logs = list(pod_logs)
log_text = "\n".join(map(str, logs))

print(f"Training job final status: {job.status}")

# Check 1: a Failed status is fatal - show the tail of the logs and stop.
if job.status == "Failed":
    print(f"ERROR: Training job '{job_name}' has Failed status")
    print("Last 30 lines of logs:")
    for tail_line in logs[-30:]:
        print(tail_line)
    raise RuntimeError(f"Training job '{job_name}' failed")

# Check 2: the training script may catch exceptions and still exit 0, so a
# non-Failed status is not proof of success; require the explicit completion marker.
completion_marker = "[PY] LORA_SFT training complete. Result="
if completion_marker not in log_text:
    print("ERROR: Training completion message not found in logs")
    print("Last 50 lines of logs:")
    for tail_line in logs[-50:]:
        print(tail_line)
    raise RuntimeError("Training did not complete successfully - missing completion message")

print(f"✓ Training job '{job_name}' completed successfully")

In [None]:
# Print one summary line per step reported for the job.
job_steps = client.get_job(name=job_name).steps
for step in job_steps:
    print(f"Step: {step.name}, Status: {step.status}, Devices: {step.device} x {step.device_count}\n")

In [None]:
# Re-fetch the job logs and print them in full. A fresh call is required: the
# log generator obtained in the validation cell was already consumed there by
# `list(...)`, so re-reading it would yield nothing.
pod_logs = client.get_job_logs(name=job_name, follow=False)

# Collect all log lines from the generator into a list
logs = list(pod_logs)
log_text = "\n".join(str(line) for line in logs)
print(log_text)

In [None]:
# Clean up: delete the training job submitted earlier in this notebook.
client.delete_job(job_name)