In [None]:
def train_fashion_mnist():
    """Train a small CNN on Fashion-MNIST with DistributedDataParallel on CPU.

    The function takes no arguments and is self-contained: every import and
    helper lives inside the body. It is passed as func= to CustomTrainer in a
    later cell (presumably so it can be shipped to the training pods as-is;
    keep it free of references to notebook globals).

    Steps:
      1. Initialise the default process group with the gloo backend.
      2. Wrap the CNN in DistributedDataParallel and build an SGD optimizer.
      3. Load the raw IDX training files from /mnt/shared/FashionMNIST/raw.
      4. Train for two epochs, sharding the data with DistributedSampler.
      5. Print "Training is finished" on rank 0 and destroy the process group.

    Raises:
        RuntimeError: if the dataset directory is missing or empty, a required
            training file is absent, or an IDX file has an unexpected magic
            number or a truncated header/payload.
    """
    import os
    import struct

    import numpy as np
    import torch
    import torch.distributed as dist
    import torch.nn.functional as F
    from torch import nn
    from torch.utils.data import DataLoader, DistributedSampler, Dataset

    # Define the PyTorch CNN model to be trained
    class Net(nn.Module):
        """Two conv layers followed by two fully connected layers, 10 outputs."""

        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5, 1)
            self.conv2 = nn.Conv2d(20, 50, 5, 1)
            # 4 x 4 x 50 is the feature-map size after two (conv 5x5, pool 2x2)
            # stages, which is what a 28 x 28 single-channel input produces.
            self.fc1 = nn.Linear(4 * 4 * 50, 500)
            self.fc2 = nn.Linear(500, 10)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            x = F.max_pool2d(x, 2, 2)
            x = F.relu(self.conv2(x))
            x = F.max_pool2d(x, 2, 2)
            x = x.view(-1, 4 * 4 * 50)
            x = F.relu(self.fc1(x))
            x = self.fc2(x)
            # Log-probabilities, to be paired with F.nll_loss in the training loop.
            return F.log_softmax(x, dim=1)

    # Force CPU-only for this test to avoid accidental NCCL/GPU usage
    backend = "gloo"
    device = torch.device("cpu")
    print(f"Using Device: cpu, Backend: {backend}")

    # Setup PyTorch distributed. local_rank is only used for the log line below.
    local_rank = int(os.getenv("LOCAL_RANK") or os.getenv("PET_NODE_RANK") or 0)
    dist.init_process_group(backend=backend)
    print(
        "Distributed Training for WORLD_SIZE: {}, RANK: {}, LOCAL_RANK: {}".format(
            dist.get_world_size(),
            dist.get_rank(),
            local_rank,
        )
    )

    # Create the model and load it into the device.
    model = nn.parallel.DistributedDataParallel(Net().to(device))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    # The dataset must already be present on the shared PVC; there is no
    # download fallback inside the training pods.
    pvc_root = "/mnt/shared"
    pvc_raw = os.path.join(pvc_root, "FashionMNIST", "raw")
    images_path = os.path.join(pvc_raw, "train-images-idx3-ubyte")
    labels_path = os.path.join(pvc_raw, "train-labels-idx1-ubyte")

    if not os.path.isdir(pvc_raw) or not os.listdir(pvc_raw):
        raise RuntimeError("Shared PVC not mounted or empty at /mnt/shared/FashionMNIST/raw; this test requires a pre-populated RWX PVC")

    # A non-empty directory is not enough: verify the two files we actually read.
    missing_files = [p for p in (images_path, labels_path) if not os.path.isfile(p)]
    if missing_files:
        raise RuntimeError(f"Required dataset files missing on shared PVC: {missing_files}")

    print("Using dataset from shared PVC at /mnt/shared")

    def _read_idx_images(path):
        """Read an IDX3 image file into a uint8 array of shape (num, rows, cols)."""
        with open(path, "rb") as f:
            header = f.read(16)
            if len(header) != 16:
                raise RuntimeError(f"Truncated images header in {path}")
            magic, num, rows, cols = struct.unpack(">IIII", header)
            if magic != 2051:
                raise RuntimeError(f"Unexpected images magic: {magic}")
            data = f.read()
        # Catch partially written files before reshape fails with a vaguer error.
        if len(data) != num * rows * cols:
            raise RuntimeError(
                f"Images payload size mismatch in {path}: "
                f"expected {num * rows * cols} bytes, got {len(data)}"
            )
        return np.frombuffer(data, dtype=np.uint8).reshape(num, rows, cols)

    def _read_idx_labels(path):
        """Read an IDX1 label file into a flat uint8 array."""
        with open(path, "rb") as f:
            header = f.read(8)
            if len(header) != 8:
                raise RuntimeError(f"Truncated labels header in {path}")
            magic, num = struct.unpack(">II", header)
            if magic != 2049:
                raise RuntimeError(f"Unexpected labels magic: {magic}")
            data = f.read()
        if len(data) != num:
            raise RuntimeError(
                f"Labels payload size mismatch in {path}: expected {num} bytes, got {len(data)}"
            )
        return np.frombuffer(data, dtype=np.uint8)

    class MnistIdxDataset(Dataset):
        """Map-style dataset over a pair of raw IDX image/label files."""

        def __init__(self, images_path: str, labels_path: str):
            self.images = _read_idx_images(images_path)
            self.labels = _read_idx_labels(labels_path)
            if len(self.images) != len(self.labels):
                raise RuntimeError("Images and labels count mismatch")

        def __len__(self):
            return len(self.labels)

        def __getitem__(self, idx: int):
            # Add a leading channel axis and scale to [0, 1]; astype copies, so
            # the tensor does not alias the read-only frombuffer array.
            img = torch.from_numpy(self.images[idx][None, ...].astype("float32") / 255.0)
            label = int(self.labels[idx])
            return img, label

    dataset = MnistIdxDataset(images_path, labels_path)
    sampler = DistributedSampler(dataset)
    train_loader = DataLoader(
        dataset,
        batch_size=100,
        sampler=sampler,
    )

    dist.barrier()
    for epoch in range(1, 3):
        model.train()
        # Without set_epoch the sampler replays the same shuffle every epoch.
        sampler.set_epoch(epoch)

        # Iterate over mini-batches from the training set
        for batch_idx, (inputs, labels) in enumerate(train_loader):
            # Move the data to the selected device
            inputs, labels = inputs.to(device), labels.to(device)
            # Forward pass
            outputs = model(inputs)
            loss = F.nll_loss(outputs, labels)
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % 10 == 0 and dist.get_rank() == 0:
                print(
                    "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                        epoch,
                        batch_idx * len(inputs),
                        # Samples assigned to this rank, not the full dataset.
                        len(sampler),
                        100.0 * batch_idx / len(train_loader),
                        loss.item(),
                    )
                )

    # Wait for the distributed training to complete
    dist.barrier()
    if dist.get_rank() == 0:
        # The notebook's validation cell searches the job logs for this exact text.
        print("Training is finished")

    # Finally clean up PyTorch distributed
    dist.destroy_process_group()

In [None]:
import os
import gzip
import shutil
import socket
import time

import boto3
from botocore.config import Config as BotoConfig
from botocore.exceptions import ClientError

# Process-wide cap on blocking socket calls: newly created sockets inherit a
# 10 second timeout, so a stalled read raises socket.timeout instead of
# hanging the notebook indefinitely.
# NOTE(review): this default applies to every socket created afterwards in
# this kernel, not only the S3 download below - confirm that is intended.
socket.setdefaulttimeout(10)  # seconds

# Where the PVC is mounted inside the notebook (per Notebook CR). The training
# pods mount the same PVC at /mnt/shared, so files written here show up there.
PVC_NOTEBOOK_PATH = "/opt/app-root/src"
DATASET_ROOT_NOTEBOOK = PVC_NOTEBOOK_PATH  # FashionMNIST lives directly under this root
FASHION_RAW_DIR = os.path.join(DATASET_ROOT_NOTEBOOK, "FashionMNIST", "raw")
os.makedirs(FASHION_RAW_DIR, exist_ok=True)

# S3/MinIO connection settings; every value falls back to "" when unset.
# s3_prefix is the directory inside the bucket, e.g. "data".
s3_endpoint, s3_access_key, s3_secret_key, s3_bucket, s3_prefix = (
    os.environ.get(env_name, "")
    for env_name in (
        "AWS_DEFAULT_ENDPOINT",
        "AWS_ACCESS_KEY_ID",
        "AWS_SECRET_ACCESS_KEY",
        "AWS_STORAGE_BUCKET",
        "AWS_STORAGE_BUCKET_MNIST_DIR",
    )
)

def stream_download(s3, bucket, key, dst):
    """Stream one S3/MinIO object to a local file.

    Uses ``get_object`` plus chunked reads of the response body, bypassing
    boto3's TransferManager / ``download_file`` entirely.

    The data is written to ``<dst>.part`` and renamed to ``dst`` only after the
    last chunk arrived. A failed or timed-out transfer therefore never leaves a
    truncated file at ``dst`` that a later "already downloaded?" existence
    check would mistake for a complete one.

    Args:
        s3: client object exposing ``get_object(Bucket=..., Key=...)`` whose
            result maps ``"Body"`` to a readable stream.
        bucket: bucket name.
        key: object key inside the bucket.
        dst: local destination path; its parent directory must already exist.

    Returns:
        True on success, False on any error (including timeouts). Errors are
        printed, never raised.
    """
    print(f"[notebook] STREAM download s3://{bucket}/{key} -> {dst}")
    t0 = time.time()

    try:
        # Metadata / headers fetch — should be quick or fail clearly
        resp = s3.get_object(Bucket=bucket, Key=key)
    except ClientError as e:
        err = e.response.get("Error", {})
        print(f"[notebook] CLIENT ERROR (get_object) for {key}: {err}")
        return False
    except Exception as e:
        print(f"[notebook] OTHER ERROR (get_object) for {key}: {e}")
        return False

    body = resp["Body"]
    part_path = f"{dst}.part"
    completed = False
    try:
        with open(part_path, "wb") as f:
            while True:
                try:
                    # Each read is bounded by socket.setdefaulttimeout(...)
                    chunk = body.read(1024 * 1024)  # 1MB per chunk
                except socket.timeout as e:
                    print(f"[notebook] socket.timeout while reading {key}: {e}")
                    return False
                if not chunk:
                    break
                f.write(chunk)
        # Publish the file under its final name only once it is complete.
        os.replace(part_path, dst)
        completed = True
    except Exception as e:
        print(f"[notebook] ERROR writing to {dst} for {key}: {e}")
        return False
    finally:
        # Release the HTTP connection held by the streaming body (best effort).
        close_body = getattr(body, "close", None)
        if callable(close_body):
            try:
                close_body()
            except Exception:
                pass
        # Drop the partial file on every failure path, including early returns.
        if not completed:
            try:
                os.remove(part_path)
            except OSError:
                pass

    t1 = time.time()
    print(f"[notebook] DONE  stream {key} in {t1 - t0:.2f}s")
    return True


def _gunzip_atomic(src_gz, out_path):
    """Decompress ``src_gz`` into ``out_path``; return True on success.

    The data goes to ``<out_path>.tmp`` first and is renamed only when the
    whole archive decompressed cleanly, so a corrupt or interrupted archive
    never leaves a truncated ``out_path`` that later existence checks would
    treat as valid. Errors are printed, not raised.
    """
    tmp_path = f"{out_path}.tmp"
    try:
        with gzip.open(src_gz, "rb") as f_in, open(tmp_path, "wb") as f_out:
            shutil.copyfileobj(f_in, f_out)
        os.replace(tmp_path, out_path)
        return True
    except Exception as e:
        print(f"[notebook] Failed to decompress {src_gz}: {e}")
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        return False


# Pull the dataset from S3/MinIO into FASHION_RAW_DIR. Every failure here is
# logged and swallowed on purpose: the next section falls back to an internet
# download when files are still missing.
if s3_endpoint and s3_bucket:
    try:
        # Normalize endpoint URL
        endpoint_url = (
            s3_endpoint
            if s3_endpoint.startswith("http")
            else f"https://{s3_endpoint}"
        )
        prefix = (s3_prefix or "").strip("/")
        # List with a trailing slash so prefix "data" cannot also match
        # sibling keys such as "database/...".
        list_prefix = f"{prefix}/" if prefix else ""

        print(
            f"S3 configured (boto3, notebook): "
            f"endpoint={endpoint_url}, bucket={s3_bucket}, prefix={prefix or '<root>'}"
        )

        # Boto config: single attempt, reasonable connect/read timeouts.
        boto_cfg = BotoConfig(
            signature_version="s3v4",
            s3={"addressing_style": "path"},
            retries={"max_attempts": 1, "mode": "standard"},
            connect_timeout=5,
            read_timeout=10,
        )

        # Create S3/MinIO client
        # NOTE(review): verify=False disables TLS certificate verification.
        # Presumably needed for a self-signed test endpoint; do not reuse this
        # against untrusted networks.
        s3 = boto3.client(
            "s3",
            endpoint_url=endpoint_url,
            aws_access_key_id=s3_access_key,
            aws_secret_access_key=s3_secret_key,
            config=boto_cfg,
            verify=False,
        )

        # Optional: quick debug HEAD of one expected key under the configured
        # prefix (will just log if there's an access or existence problem)
        test_key = f"{list_prefix}t10k-labels-idx1-ubyte.gz"
        try:
            print(f"[debug] HEAD s3://{s3_bucket}/{test_key}")
            meta = s3.head_object(Bucket=s3_bucket, Key=test_key)
            print(f"[debug] HEAD OK: size={meta.get('ContentLength')}")
        except ClientError as e:
            print(f"[debug] HEAD ERROR for {test_key}: {e.response.get('Error')}")

        # List and download all objects under the prefix
        paginator = s3.get_paginator("list_objects_v2")
        pulled_any = False
        raw_root = os.path.realpath(FASHION_RAW_DIR)

        for page in paginator.paginate(Bucket=s3_bucket, Prefix=list_prefix):
            for obj in page.get("Contents", []):
                key = obj["Key"]

                # Skip "directory markers"
                if key.endswith("/"):
                    continue

                # Determine relative path under prefix for local storage
                rel = key[len(list_prefix):].lstrip("/")
                if not rel:
                    continue
                dst = os.path.join(FASHION_RAW_DIR, rel)

                # Object keys are remote input: refuse anything (e.g. "../")
                # that would resolve outside the dataset directory.
                if os.path.commonpath([raw_root, os.path.realpath(dst)]) != raw_root:
                    print(f"[notebook] Skipping key outside dataset dir: {key}")
                    continue
                os.makedirs(os.path.dirname(dst), exist_ok=True)

                is_gz = dst.endswith(".gz")
                final_path = os.path.splitext(dst)[0] if is_gz else dst

                # Nothing to do when the usable (decompressed) file is already
                # there; this also avoids re-downloading an archive that an
                # earlier run decompressed and deleted.
                if os.path.exists(final_path):
                    print(f"[notebook] Skipping existing file {final_path}")
                    pulled_any = True
                    continue

                # Download only if missing
                if not os.path.exists(dst):
                    if not stream_download(s3, s3_bucket, key, dst):
                        # Skip decompression and move on to next object
                        continue
                else:
                    print(f"[notebook] Skipping existing file {dst}")

                # If the file is .gz, decompress and remove the .gz
                if is_gz:
                    print(f"[notebook] Decompressing {dst} -> {final_path}")
                    decompressed = _gunzip_atomic(dst, final_path)
                    try:
                        # Remove the archive either way: it is redundant after
                        # success, and a bad one must not be reused next run.
                        os.remove(dst)
                    except OSError as e:
                        # Not critical if we can't delete the gzip
                        print(f"[notebook] Could not remove {dst}: {e}")
                    if not decompressed:
                        continue

                pulled_any = True

        print(f"[notebook] S3 pulled_any={pulled_any}")

    except Exception as e:
        print(f"[notebook] S3 fetch failed: {e}")
else:
    print("[notebook] S3 not configured: missing endpoint or bucket env vars")

# Check if we have data; if not, try downloading from internet.
# All four raw IDX files must exist (decompressed) in FASHION_RAW_DIR.
files_needed = [
    "train-images-idx3-ubyte",
    "train-labels-idx1-ubyte",
    "t10k-images-idx3-ubyte",
    "t10k-labels-idx1-ubyte",
]
files_present = all(os.path.exists(os.path.join(FASHION_RAW_DIR, f)) for f in files_needed)

if not files_present:
    print("[notebook] Dataset not complete, attempting internet download...")
    try:
        import urllib.request

        base_url = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"
        gz_files = [f"{name}.gz" for name in files_needed]

        for gz_file in gz_files:
            url = base_url + gz_file
            dst_gz = os.path.join(FASHION_RAW_DIR, gz_file)
            dst = os.path.splitext(dst_gz)[0]

            if os.path.exists(dst):
                print(f"[notebook] Already have {dst}, skipping download")
                continue

            # Decompress into a temporary name and rename at the end, so an
            # interrupted run never leaves a truncated file at dst that the
            # "Already have" check above would accept on the next run.
            dst_tmp = f"{dst}.tmp"
            try:
                print(f"[notebook] Downloading {url} ...")
                urllib.request.urlretrieve(url, dst_gz)

                print(f"[notebook] Decompressing {dst_gz} ...")
                with gzip.open(dst_gz, "rb") as f_in, open(dst_tmp, "wb") as f_out:
                    shutil.copyfileobj(f_in, f_out)
                os.replace(dst_tmp, dst)
            finally:
                # Best-effort cleanup of the archive and of any partial output,
                # on success and on failure alike.
                for leftover in (dst_tmp, dst_gz):
                    try:
                        os.remove(leftover)
                    except OSError:
                        pass

            print(f"[notebook] Done: {dst}")

        print("[notebook] Internet download completed successfully")
    except Exception as e:
        print(f"[notebook] Internet download failed: {e}")
        # Chain the original error so the traceback shows the root cause.
        raise RuntimeError("Internet download failed; aborting test") from e
else:
    print("[notebook] Dataset files already present, skipping download")


In [None]:
# Init SDK client with user token/API URL (no Backend types import)
import os
from kubernetes import client as k8s_client
from kubeflow.trainer import TrainerClient
from kubeflow.common.types import KubernetesBackendConfig

# Connection settings come from the environment; nothing is hard-coded.
openshift_api_url = os.getenv("OPENSHIFT_API_URL", "")
token = os.getenv("NOTEBOOK_TOKEN", "")

# Fail here with a clear message instead of building a client that can only
# produce opaque connection/authentication errors in later cells.
missing_env = [
    env_name
    for env_name, env_value in (
        ("OPENSHIFT_API_URL", openshift_api_url),
        ("NOTEBOOK_TOKEN", token),
    )
    if not env_value
]
if missing_env:
    raise RuntimeError(
        f"Missing required environment variable(s): {', '.join(missing_env)}"
    )

cfg = k8s_client.Configuration()
cfg.host = openshift_api_url
# NOTE(review): TLS verification is disabled - presumably because the test
# cluster uses a self-signed certificate. Do not copy this into production code.
cfg.verify_ssl = False
# The token is sent as a bearer token in the Authorization header.
cfg.api_key = {"authorization": f"Bearer {token}"}

api_client = k8s_client.ApiClient(cfg)

backend_cfg = KubernetesBackendConfig(
    client_configuration=api_client.configuration,
)

# `client` is the handle every later cell uses to talk to the Trainer API.
client = TrainerClient(backend_cfg)

In [None]:

# Look up the runtime the job will be submitted against. Any lookup failure is
# re-raised as a RuntimeError with a fixed message; the original exception
# stays attached as the cause.
try:
    torch_runtime = client.get_runtime("torch-distributed")
except Exception as lookup_error:
    raise RuntimeError(
        "Runtime 'torch-distributed' not found or not accessible"
    ) from lookup_error

In [None]:
import os
from kubeflow.trainer import CustomTrainer
from kubeflow.trainer.options import PodTemplateOverrides, PodTemplateOverride, PodSpecOverride, ContainerOverride

# Name of the shared PVC that already holds the dataset (populated by the
# download cell above through the notebook's own mount of the same claim).
pvc_name = os.getenv("SHARED_PVC_NAME", "")
if not pvc_name:
    # An empty claimName cannot yield a working pod; stop before submitting.
    raise RuntimeError(
        "SHARED_PVC_NAME is not set; the training pods need the shared PVC "
        "mounted at /mnt/shared"
    )
print(f"[notebook] Using PVC: {pvc_name}")

job_name = client.train(
    trainer=CustomTrainer(
        func=train_fashion_mnist,
        num_nodes=2,
        resources_per_node={
            "cpu": 2,
            "memory": "8Gi",
        },
    ),
    runtime=torch_runtime,
    options=[
        # Mount the shared PVC into the "node" container of the "node" job at
        # /mnt/shared, the path train_fashion_mnist reads the dataset from.
        PodTemplateOverrides(
            PodTemplateOverride(
                target_jobs=["node"],
                spec=PodSpecOverride(
                    volumes=[
                        {
                            "name": "work",
                            "persistentVolumeClaim": {"claimName": pvc_name},
                        }
                    ],
                    containers=[
                        ContainerOverride(
                            name="node",
                            volume_mounts=[
                                {"name": "work", "mountPath": "/mnt/shared", "readOnly": False}
                            ],
                        )
                    ],
                )
            )
        )
    ],
)

print(f"[notebook] Job submitted: {job_name}")

In [None]:
# Wait for the job to start, then wait for completion or failure.
# The first wait also accepts the terminal states: a short job may already be
# Complete (or Failed) by the time its status is checked, and waiting for
# "Running" alone would then never match. Accepting "Failed" here also lets
# Check 1 below report the failure together with the job logs.
client.wait_for_job_status(
    name=job_name, status={"Running", "Complete", "Failed"}, timeout=300
)
client.wait_for_job_status(name=job_name, status={"Complete", "Failed"}, timeout=900)

# Get job details and logs
job = client.get_job(name=job_name)
pod_logs = client.get_job_logs(name=job_name, follow=False)

# Collect all log lines from the generator into a list
logs = list(pod_logs)
log_text = "\n".join(str(line) for line in logs)


print(f"Training job final status: {job.status}")

# Check 1: Job status must not be "Failed"
if job.status == "Failed":
    print(f"ERROR: Training job '{job_name}' has Failed status")
    print("Last 30 lines of logs:")
    for line in logs[-30:]:
        print(line)
    raise RuntimeError(f"Training job '{job_name}' failed")

# Check 2: Look for the training completion message in logs
# This is critical because the training script may catch exceptions and exit 0
if "Training is finished" not in log_text:
    print("ERROR: Training completion message not found in logs")
    print("Last 50 lines of logs:")
    for line in logs[-50:]:
        print(line)
    raise RuntimeError("Training did not complete successfully - missing completion message")

print(f"✓ Training job '{job_name}' completed successfully")

In [None]:
# Summarise each step of the job: name, status and the devices it used.
job_steps = client.get_job(name=job_name).steps
for step in job_steps:
    print(f"Step: {step.name}, Status: {step.status}, Devices: {step.device} x {step.device_count}\n")

In [None]:
# Print the job's log lines one by one, following the stream until it ends.
log_stream = client.get_job_logs(job_name, follow=True)
for entry in log_stream:
    print(entry)

In [None]:
# Clean up: delete the training job submitted earlier in this notebook.
client.delete_job(job_name)