In [1]:
import kfp
import kfp.dsl as dsl

# Container image that the GPU-check component below runs in.
#
# Chosen image: the NVIDIA NGC PyTorch container, noted here as the base image
# of the Katib pytorch-mnist GPU trial image.
#   - Katib image tags: https://hub.docker.com/r/kubeflowkatib/pytorch-mnist-gpu/tags
#   - Katib Dockerfile: https://github.com/kubeflow/katib/blob/master/examples/v1beta1/trial-images/pytorch-mnist/Dockerfile.gpu
#   - NGC framework docs: https://docs.nvidia.com/deeplearning/frameworks/index.html
#
# Alternatives tried earlier:
#   - 'tensorflow/tensorflow:latest-gpu' -> TensorFlow could NOT see any GPUs; no GPU acceleration.
#   - 'kubeflowkatib/pytorch-mnist-gpu'  -> run did not seem to finish (unconfirmed).
#   - 'nvcr.io/nvidia/pytorch:25.01-py3' -> newer NGC tag; no outcome recorded.
BASE_IMAGE = 'nvcr.io/nvidia/pytorch:24.01-py3'
print("Using PyTorch Base Image: " + BASE_IMAGE)


Using PyTorch Base Image: nvcr.io/nvidia/pytorch:24.01-py3


In [2]:
@dsl.component(
    base_image=BASE_IMAGE,
)
def check_gpu_access_pytorch(require_gpu: bool = False):
    """
    Check and log GPU availability from inside the component container using PyTorch.

    Logs the PyTorch version and the number of CUDA devices PyTorch can see.
    It also logs each device's name and total memory (GiB) and the outcome of
    a tiny 3x3 matmul on ``cuda:0``.

    All imports live inside the function body because a KFP lightweight
    component ships only this function's source to the container.

    Args:
        require_gpu: When True, the step fails with ``RuntimeError`` unless the
            test computation on ``cuda:0`` succeeds. This covers three cases:
            CUDA is unavailable, no devices are found, or the computation
            fails. Defaults to False, in which case those conditions are only
            logged and the step still succeeds.

    Raises:
        RuntimeError: If ``require_gpu`` is True and no GPU computation succeeded.
        Exception: Any unexpected error while querying PyTorch/CUDA is logged
            with its traceback and re-raised so the KFP step fails.
    """
    import logging
    import os

    import torch

    logging.basicConfig(level=logging.INFO, format='%(levelname)s:%(asctime)s:%(message)s')

    # NOTE(review): assumes the runtime exports KFP_COMPONENT_IMAGE; confirm that
    # your KFP version actually sets it, otherwise the warning branch always runs.
    actual_base_image = os.getenv('KFP_COMPONENT_IMAGE')
    if actual_base_image:
        logging.info(f"Running inside Base Image: {actual_base_image}")
    else:
        logging.warning("Could not determine base image from KFP_COMPONENT_IMAGE env variable.")

    # Becomes True only once a real computation on cuda:0 has succeeded.
    gpu_usable = False

    try:
        logging.info(f"PyTorch version: {torch.__version__}")
        # Check if CUDA is available at all (drivers found, PyTorch CUDA compiled)
        if not torch.cuda.is_available():
            logging.warning("PyTorch CUDA is NOT available. GPU acceleration is NOT possible.")
        else:
            gpu_count = torch.cuda.device_count()
            if gpu_count == 0:
                logging.warning("PyTorch CUDA is available, but no GPUs were found by PyTorch.")
            else:
                logging.info(f"PyTorch found {gpu_count} GPU(s).")
                for i in range(gpu_count):
                    gpu_name = torch.cuda.get_device_name(i)
                    gpu_memory = torch.cuda.get_device_properties(i).total_memory / (1024**3)  # GiB
                    logging.info(f"  GPU {i}: {gpu_name} - Memory: {gpu_memory:.2f} GiB")

                # Enumeration alone does not prove the device is usable, so run a
                # small computation on the first GPU. A failure here is logged and
                # only fails the step when require_gpu is True (checked below).
                try:
                    device = torch.device('cuda:0')
                    # Allocate directly on the GPU instead of creating on CPU and copying.
                    tensor_a = torch.randn(3, 3, device=device)
                    tensor_b = torch.randn(3, 3, device=device)
                    result = torch.matmul(tensor_a, tensor_b)
                    logging.info(f"Simple PyTorch computation on GPU 0 successful: First element={result[0,0].item()}")
                    gpu_usable = True
                except Exception as compute_e:
                    logging.error(f"PyTorch detected GPU but failed test computation: {compute_e}", exc_info=True)

    except Exception as e:
        logging.error(f"An error occurred while checking for GPUs with PyTorch: {e}", exc_info=True)  # Log traceback
        # Bare re-raise keeps the original traceback and makes the KFP step fail.
        raise

    if require_gpu and not gpu_usable:
        raise RuntimeError("require_gpu=True but PyTorch could not run a computation on a GPU.")

In [3]:
# --- Example Pipeline Definition (modify as needed) ---
@dsl.pipeline(
    name='gpu-test-pipeline-pt',
    description='A minimal pipeline to test GPU access using PyTorch.'
)
def pytorch_gpu_test_pipeline():
    """Single-step pipeline that runs the PyTorch GPU check on one GPU.

    The step asks for one 'nvidia.com/gpu' accelerator, a CPU request/limit of
    0.5/1, and a memory request/limit of 2G/4G.
    """
    check_gpu_task = (
        check_gpu_access_pytorch()  # Call the PyTorch version
        .set_display_name("Check PyTorch GPU Availability")
        # KFP 2.x accelerator API. It replaces the deprecated
        # set_gpu_limit(1) / add_node_selector_constraint('nvidia.com/gpu') pair
        # with the same values.
        # NOTE(review): 'nvidia.com/gpu' is passed as the accelerator *type*, not a
        # node label; the cluster's GPU nodes must expose this resource.
        .set_accelerator_type('nvidia.com/gpu')
        .set_accelerator_limit(1)
        .set_memory_limit("4G")
        .set_memory_request("2G")
        .set_cpu_limit("1")
        .set_cpu_request("0.5")
    )

In [4]:
# Submit the GPU-test pipeline using a KFP client with default connection settings.
# NOTE(review): the run is filed under the "mnist_pipeline" experiment even though
# this is a GPU smoke test; confirm that is the intended experiment.
client = kfp.Client()

run_result = client.create_run_from_pipeline_func(
    pytorch_gpu_test_pipeline,
    experiment_name="mnist_pipeline",
)
run_result



RunPipelineResult(run_id=4148aa0d-45c8-4835-b90e-890bc7821458)