In [1]:
%%writefile ../../.devcontainer/.env.template
# Template environment for the dev container. docker-compose.yml loads this
# file through `env_file`, so every variable below is set inside the
# `datascience` container.
# NOTE(review): Compose substitutes the ${VAR} references written in
# docker-compose.yml from the shell or a project-level .env file, presumably
# not from `env_file` entries -- copy this file to .env if the port, CUDA_TAG,
# PYTHON_VER or ENV_NAME values should drive the build and port mappings.
ENV_NAME=docker_dev_template

# Build arguments: CUDA base-image tag and Python version (RTX 4090 target)
CUDA_TAG=12.4.0
PYTHON_VER=3.10

# Host ports that docker-compose.yml maps onto the container services
HOST_JUPYTER_PORT=8891
HOST_TENSORBOARD_PORT=6008
HOST_EXPLAINER_PORT=8050
HOST_STREAMLIT_PORT=8501
HOST_MLFLOW_PORT=5000

# JAX/GPU configuration - CRITICAL: NO INLINE COMMENTS
# These values are parsed verbatim (validate_gpu.py rejects any value that
# contains '#'), so keep every comment on its own line.

# Fraction of GPU memory JAX may claim (0.0 to 1.0)
# 0.4 of the RTX 4090's 24 GB VRAM leaves room for PyTorch on the same device
XLA_PYTHON_CLIENT_MEM_FRACTION=0.4

# Do not grab the memory pool up front; allocate as needed
XLA_PYTHON_CLIENT_PREALLOCATE=false

# Platform allocator: allocate exactly what is needed on demand and release it
# when no longer used. The JAX documentation describes this mode as slow and
# intended for minimal-footprint runs or debugging OOM failures, not as the
# fastest option.
# NOTE(review): with the platform allocator the two preallocation settings
# above are presumably not consulted -- confirm.
XLA_PYTHON_CLIENT_ALLOCATOR=platform

# XLA compiler flags: location of the CUDA toolkit inside the image
XLA_FLAGS=--xla_gpu_cuda_data_dir=/usr/local/cuda

# Preallocation limit in bytes: 17179869184 = 16 GiB
# NOTE(review): this name does not appear in JAX's GPU memory allocation
# documentation -- confirm that JAX actually reads it.
JAX_PREALLOCATION_SIZE_LIMIT_BYTES=17179869184

# JAX behavior configuration: keep JIT enabled, default to 32-bit types
JAX_DISABLE_JIT=false
JAX_ENABLE_X64=false

# TensorFlow GPU configuration (only relevant if TensorFlow is installed)
TF_FORCE_GPU_ALLOW_GROWTH=true


Overwriting ../../.devcontainer/.env.template


In [2]:
%%writefile ../../.devcontainer/.dockerignore
# Reduce Docker build context.
# NOTE(review): docker-compose.yml builds with context `.` (repository root)
# and dockerfile `.devcontainer/Dockerfile`. Docker reads ignore rules from
# `<context>/.dockerignore` or from `<Dockerfile path>.dockerignore`, so a file
# stored at `.devcontainer/.dockerignore` is probably not applied -- confirm.
# Patterns are matched relative to the context root; order is kept as-is.

# Version control, editors and OS metadata
.git
.gitignore
.gitattributes
.gitmodules
.vscode
.idea
*.swp
*.swo
*~
.DS_Store
Thumbs.db
# Python bytecode, caches and virtual environments
__pycache__
*.pyc
*.pyo
*.pyd
.Python
*.so
.coverage*
.cache
.pytest_cache
.mypy_cache
.tox
pip-log.txt
pip-delete-this-directory.txt
env
venv
ENV
env.bak
venv.bak
.ipynb_checkpoints
# Large data (adjust as needed)
data/raw
data/external
*.csv
*.parquet
*.h5
*.hdf5
# Serialized models and checkpoints
*.pt
*.pth
*.pkl
*.joblib
models/
# Logs and temporary files
*.log
logs/
*.tmp
*.temp
.tmp
temp/
# Build artifacts
build/
dist/
*.egg-info/
.eggs/
# Node tooling (the two tarball patterns below also cover generic archives)
node_modules
npm-debug.log*
yarn-*.log*
.npm
.eslintcache
.node_repl_history
*.tgz
*.tar.gz
# Archives
*.zip
*.tar
*.tar.bz2
*.rar
*.7z
# Docs (remove these patterns if the image needs them; note that
# pyproject.toml names README.md as the project readme)
docs/
*.md
README*
LICENSE*
CHANGELOG*
# Tests (remove these patterns if the image needs them)
tests/
test_*
*_test.py
# CI configuration
.github/
.gitlab-ci.yml
.travis.yml
.circleci/
azure-pipelines.yml
# Local environment files and linter/formatter configuration
.env
.env.local
.env.*.local
.editorconfig
.prettierrc*
.eslintrc*
# Catch-all for Python bytecode; overlaps the *.pyc / *.pyo / *.pyd entries above
*.py[cod]

Overwriting ../../.devcontainer/.dockerignore


In [3]:
%%writefile ../../.devcontainer/devcontainer.json
{
  "name": "docker_dev_template_rtx4090",
  "dockerComposeFile": "../docker-compose.yml",
  "service": "datascience",
  "workspaceFolder": "/workspace",
  "shutdownAction": "stopCompose",

  "overrideCommand": false,
  "containerEnv": {
    "CONTAINER_WORKSPACE_FOLDER": "/workspace",
    "UV_PROJECT_ENVIRONMENT": "/app/.venv",
    "VIRTUAL_ENV": "/app/.venv",
    "PYTHONPATH": "/workspace",
    "TERM": "xterm-256color"
  },

  "customizations": {
    "vscode": {
      "settings": {
        "python.defaultInterpreterPath": "/app/.venv/bin/python",
        "python.pythonPath": "/app/.venv/bin/python",
        "python.terminal.activateEnvironment": true,
        "python.terminal.activateEnvInCurrentTerminal": true,
        "terminal.integrated.defaultProfile.linux": "bash",
        "terminal.integrated.profiles.linux": {
          "bash": {
            "path": "/bin/bash",
            "args": ["-l"],
            "env": {
              "VIRTUAL_ENV": "/app/.venv",
              "PATH": "/app/.venv/bin:${env:PATH}",
              "UV_PROJECT_ENVIRONMENT": "/app/.venv",
              "PYTHONPATH": "/workspace"
            }
          }
        },
        "jupyter.notebookFileRoot": "/workspace",
        "jupyter.kernels.filter": [
          {
            "path": "/app/.venv/bin/python",
            "type": "pythonEnvironment"
          }
        ]
      },
      "extensions": [
        "ms-python.python",
        "ms-toolsai.jupyter",
        "ms-azuretools.vscode-docker",
        "ms-python.flake8",
        "ms-python.black-formatter"
      ]
    }
  },

  "onCreateCommand": [
    "bash", "-lc",
    "echo 'onCreate: validating environment'; ls -la /app/.venv/bin/; which python || echo 'python not found in PATH'"
  ],

  "postCreateCommand": [
    "bash", "-lc",
    "set -e; source /app/.venv/bin/activate; python -c 'import sys; print(f\"python: {sys.executable}\")'; uv pip install -U ipykernel jupyter-client -q; python -m ipykernel install --user --name='uv_docker_dev_template' --display-name='Python (UV Environment)'; jupyter kernelspec list; python /app/tests/test_summary.py"
  ],

  "postStartCommand": [
    "bash", "-lc",
    "source /app/.venv/bin/activate; python --version; python -c 'import torch; print(f\"pytorch cuda: {torch.cuda.is_available()}\")' || echo 'pytorch test failed'; python /app/validate_gpu.py --quick || echo 'gpu validation completed with warnings'"
  ],

  "features": {},
  "forwardPorts": [8888, 6008, 8050, 8501, "mlflow:5000"],
  "portsAttributes": {
    "8888": { "label": "Jupyter Lab", "onAutoForward": "notify" },
    "6008": { "label": "TensorBoard", "onAutoForward": "silent" },
    "8050": { "label": "Explainer Dashboard", "onAutoForward": "silent" },
    "8501": { "label": "Streamlit", "onAutoForward": "silent" },
    "mlflow:5000": { "label": "MLflow", "onAutoForward": "silent" }
  },

  "mounts": [
    "source=docker_dev_template_uv_cache,target=/root/.cache/uv,type=volume"
  ]
}

Overwriting ../../.devcontainer/devcontainer.json


In [4]:
%%writefile ../../.devcontainer/Dockerfile
# Dockerfile: RTX 4090 devcontainer with UV, JAX, and PyTorch (CUDA 12.x)
#
# Requires BuildKit (the `RUN --mount=type=cache` steps below). COPY paths are
# relative to the build context, which docker-compose.yml sets to the
# repository root.

ARG CUDA_TAG=12.4.0
FROM nvidia/cuda:${CUDA_TAG}-devel-ubuntu22.04

# PYTHON_VER is expected in "major.minor" form (e.g. 3.10): it is also used to
# build the site-packages path placed on LD_LIBRARY_PATH further down.
ARG PYTHON_VER=3.10
ARG ENV_NAME=docker_dev_template
ENV DEBIAN_FRONTEND=noninteractive

# System dependencies.
# /var/cache/apt and /var/lib/apt/lists are BuildKit cache mounts: their
# contents never become part of an image layer, so no `apt-get clean` or
# `rm -rf /var/lib/apt/lists/*` is needed -- running them would only empty the
# cache these mounts exist to keep.
RUN --mount=type=cache,id=apt-cache,target=/var/cache/apt,sharing=locked \
    --mount=type=cache,id=apt-lists,target=/var/lib/apt/lists,sharing=locked \
    apt-get update && apt-get install -y --no-install-recommends \
        bash curl ca-certificates git procps htop \
        python3 python3-venv python3-pip python3-dev \
        build-essential cmake pkg-config \
        libjemalloc2 libjemalloc-dev \
        iproute2 net-tools lsof wget

# UV package manager
COPY --from=ghcr.io/astral-sh/uv:0.7.12 /uv /uvx /bin/

WORKDIR /app

# Create venv managed by UV
RUN uv venv .venv --python "${PYTHON_VER}" --prompt "${ENV_NAME}"

# UV_LINK_MODE=copy: the uv cache is a separate mount/volume, so packages are
# copied into the venv instead of hard-linked across filesystems.
ENV VIRTUAL_ENV=/app/.venv \
    PATH="/app/.venv/bin:${PATH}" \
    UV_PROJECT_ENVIRONMENT=/app/.venv \
    UV_LINK_MODE=copy \
    PYTHONPATH="/workspace"

# Memory and allocator settings
ENV LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libjemalloc.so.2 \
    MALLOC_ARENA_MAX=2 \
    MALLOC_TCACHE_MAX=0 \
    PYTORCH_NO_CUDA_MEMORY_CACHING=1

# GPU-relevant environment
ENV XLA_PYTHON_CLIENT_PREALLOCATE=false \
    XLA_PYTHON_CLIENT_MEM_FRACTION=0.4 \
    XLA_PYTHON_CLIENT_ALLOCATOR=platform \
    PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:1024,expandable_segments:True \
    JAX_PREALLOCATION_SIZE_LIMIT_BYTES=17179869184

# Bring in project descriptors and tests
COPY pyproject.toml /workspace/
COPY uv.lock* /workspace/
COPY .devcontainer/validate_gpu.py /app/validate_gpu.py
COPY .devcontainer/tests/ /app/tests/

# Resolve project dependencies with UV: prefer the lock file as-is, otherwise
# resolve from pyproject.toml and write a fresh lock.
RUN --mount=type=cache,target=/root/.cache/uv,sharing=locked \
    cd /workspace && (uv sync --frozen --no-dev || (uv sync --no-dev && uv lock))

# Install PyTorch from the CUDA 12.4 wheel index.
# NOTE(review): if `uv sync` above already installed torch from the index
# configured in pyproject.toml, this step may be a no-op -- confirm which CUDA
# build ends up in the venv.
RUN --mount=type=cache,target=/root/.cache/uv,sharing=locked \
    echo "Installing PyTorch with CUDA 12.4..." && \
    uv pip install torch torchvision torchaudio \
        --index-url https://download.pytorch.org/whl/cu124

# Upgrade cuDNN to 9.8.x for JAX. The requirement specifiers are quoted: an
# unquoted `>=` is parsed by the shell as an output redirection to a file
# named `=9.8.0`, which silently drops the version constraint.
RUN --mount=type=cache,target=/root/.cache/uv,sharing=locked \
    echo "Upgrading CuDNN to 9.8.0 for JAX compatibility..." && \
    (uv pip install --upgrade "nvidia-cudnn-cu12==9.8.0.69" || \
     uv pip install --upgrade "nvidia-cudnn-cu12>=9.8.0")

# Install JAX after the cuDNN upgrade; fall back to the CPU build if the CUDA
# variant cannot be installed.
RUN --mount=type=cache,target=/root/.cache/uv,sharing=locked \
    echo "Removing any existing JAX installations..." && \
    (uv pip uninstall jax jaxlib jax-cuda12-plugin jax-cuda12-pjrt || true) && \
    echo "Installing JAX with CUDA 12 support..." && \
    (uv pip install "jax[cuda12-local]>=0.4.26" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \
     || uv pip install "jax[cpu]>=0.4.26")

# Jupyter kernel support
RUN --mount=type=cache,target=/root/.cache/uv,sharing=locked \
    uv pip install ipykernel jupyter-client jupyterlab

# CUDA libs in path - include both system and package CUDA libraries.
# The site-packages directory follows PYTHON_VER instead of a fixed python3.10.
ENV LD_LIBRARY_PATH="/app/.venv/lib:/app/.venv/lib/python${PYTHON_VER}/site-packages/nvidia/cudnn/lib:/usr/local/cuda/lib64:${LD_LIBRARY_PATH}"

# Shell activation helper with the same library paths. The LD_LIBRARY_PATH line
# is double-quoted so ${PYTHON_VER} expands at build time, while the escaped
# \${LD_LIBRARY_PATH} is written literally and expands when the script is sourced.
RUN echo '#!/bin/bash' > /app/activate_uv.sh && \
    echo 'export VIRTUAL_ENV="/app/.venv"' >> /app/activate_uv.sh && \
    echo 'export PATH="/app/.venv/bin:$PATH"' >> /app/activate_uv.sh && \
    echo 'export UV_PROJECT_ENVIRONMENT="/app/.venv"' >> /app/activate_uv.sh && \
    echo 'export PYTHONPATH="/workspace:$PYTHONPATH"' >> /app/activate_uv.sh && \
    echo "export LD_LIBRARY_PATH=\"/app/.venv/lib:/app/.venv/lib/python${PYTHON_VER}/site-packages/nvidia/cudnn/lib:/usr/local/cuda/lib64:\${LD_LIBRARY_PATH}\"" >> /app/activate_uv.sh && \
    echo 'cd /workspace' >> /app/activate_uv.sh && \
    chmod +x /app/activate_uv.sh && \
    echo 'source /app/activate_uv.sh' > /etc/profile.d/10-uv-activate.sh && \
    echo 'source /app/activate_uv.sh' >> /root/.bashrc && \
    chmod +x /etc/profile.d/10-uv-activate.sh

# Healthcheck script with CuDNN diagnostics
RUN echo '#!/bin/bash' > /app/healthcheck.sh && \
    echo 'source /app/.venv/bin/activate' >> /app/healthcheck.sh && \
    echo 'echo "=== CuDNN Version Check ==="' >> /app/healthcheck.sh && \
    echo 'python -c "import torch; print(f\"PyTorch CuDNN: {torch.backends.cudnn.version()}\")" || echo "PyTorch CuDNN check failed"' >> /app/healthcheck.sh && \
    echo 'echo "=== JAX Device Check ==="' >> /app/healthcheck.sh && \
    echo 'python -c "import jax; print(f\"JAX devices: {jax.devices()}\")" || echo "JAX device check failed"' >> /app/healthcheck.sh && \
    echo 'echo "=== GPU Validation ==="' >> /app/healthcheck.sh && \
    echo 'python /app/validate_gpu.py --quick' >> /app/healthcheck.sh && \
    chmod +x /app/healthcheck.sh

WORKDIR /workspace
CMD ["bash", "-l"]


Overwriting ../../.devcontainer/Dockerfile


In [5]:
%%writefile ../../docker-compose.yml
name: ${ENV_NAME:-docker_dev_template}

services:
  datascience:
    build:
      context: .
      dockerfile: .devcontainer/Dockerfile
      args:
        CUDA_TAG: ${CUDA_TAG:-12.4.0}
        PYTHON_VER: ${PYTHON_VER:-3.10}
        ENV_NAME: ${ENV_NAME:-docker_dev_template}
      cache_from:
        - nvidia/cuda:${CUDA_TAG:-12.4.0}-devel-ubuntu22.04

    container_name: ${ENV_NAME:-docker_dev_template}_datascience

    # Container environment defaults come from the template in .devcontainer
    env_file:
      - .devcontainer/.env.template

    restart: unless-stopped
    depends_on:
      mlflow:
        condition: service_healthy

    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: all
              capabilities: [gpu]

    init: true
    gpus: all
    shm_size: 8g
    ulimits:
      memlock: -1
      stack: 67108864

    environment:
      - PYTHON_VER=${PYTHON_VER:-3.10}
      - UV_PROJECT_ENVIRONMENT=/app/.venv
      - VIRTUAL_ENV=/app/.venv
      - PYTHONPATH=/workspace
      - NVIDIA_VISIBLE_DEVICES=all
      - NVIDIA_DRIVER_CAPABILITIES=compute,utility
      - CUDA_VISIBLE_DEVICES=0
      # Same search path the Dockerfile sets, including the pip-installed cuDNN
      # directory; this entry overrides the image value, so it must repeat it.
      - LD_LIBRARY_PATH=/app/.venv/lib:/app/.venv/lib/python${PYTHON_VER:-3.10}/site-packages/nvidia/cudnn/lib:/usr/local/cuda/lib64
      - LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libjemalloc.so.2
      - MALLOC_ARENA_MAX=2
      - MALLOC_TCACHE_MAX=0
      - PYTORCH_NO_CUDA_MEMORY_CACHING=1

      # JAX settings must not carry inline comments: the values are parsed
      # verbatim and a trailing comment causes "could not convert string to
      # float" errors.
      - XLA_PYTHON_CLIENT_PREALLOCATE=false
      - XLA_PYTHON_CLIENT_ALLOCATOR=platform
      - XLA_PYTHON_CLIENT_MEM_FRACTION=0.4
      - XLA_FLAGS=--xla_gpu_cuda_data_dir=/usr/local/cuda
      - JAX_PREALLOCATION_SIZE_LIMIT_BYTES=17179869184
      - PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:1024,expandable_segments:True
      - JUPYTER_TOKEN=${JUPYTER_TOKEN:-jupyter}

    volumes:
      - .:/workspace:delegated
      - ./mlruns:/workspace/mlruns
      - uv-cache:/root/.cache/uv

    ports:
      - "${HOST_JUPYTER_PORT:-8891}:8888"
      - "${HOST_TENSORBOARD_PORT:-6008}:6008"
      - "${HOST_EXPLAINER_PORT:-8050}:8050"
      - "${HOST_STREAMLIT_PORT:-8501}:8501"

    # The script is a literal block (|) passed as one argument to `bash -lc`.
    # A folded block (>) keeps the newlines of more-indented lines, which split
    # the `jupyter lab` invocation from its option lines so the options were
    # never applied; here the continuation is explicit with backslashes.
    # `$$` yields a literal `$` so the shell inside the container performs the
    # expansion (JUPYTER_TOKEN gets its default from `environment` above);
    # a single `${...}` is substituted by Compose on the host.
    command:
      - bash
      - -lc
      - |
        echo "[boot] Starting container: ${ENV_NAME:-docker_dev_template}"
        echo "[boot] Activating uv environment..."
        source /app/.venv/bin/activate
        echo "[boot] Environment activated - Python: $$(which python)"
        echo "[boot] UV available: $$(uv --version)"
        echo "[boot] Running GPU validation..."
        python /app/validate_gpu.py || echo "GPU validation warning - check logs"
        echo "[boot] Starting Jupyter Lab on port 8888..."
        exec jupyter lab --ip=0.0.0.0 --port=8888 --allow-root --no-browser \
          --ServerApp.token="$${JUPYTER_TOKEN}" \
          --ServerApp.allow_origin="*"

    # Healthy when PyTorch sees CUDA and JAX exposes at least one GPU device.
    # The device test uses `platform` rather than the string form, which reads
    # "cuda:0" / "CudaDevice(id=0)" on current JAX and contains no "gpu".
    healthcheck:
      test: ["CMD-SHELL", "python -c 'import torch, jax; assert torch.cuda.is_available(); assert any(getattr(d, \"platform\", \"\").lower() in (\"gpu\", \"cuda\") for d in jax.devices())' 2>/dev/null || exit 1"]
      interval: 60s
      timeout: 30s
      retries: 3
      start_period: 120s

    # Labels under the `com.docker.compose` prefix are reserved for Compose
    # itself (it sets project/service labels automatically), so only the
    # custom description label is declared here.
    labels:
      - "description=RTX 4090 GPU Dev Environment (PyTorch+JAX) - CUDA 12.4"

  mlflow:
    container_name: ${ENV_NAME:-docker_dev_template}_mlflow
    image: ghcr.io/mlflow/mlflow:latest
    # The backend store points at the mounted /mlflow_db directory (absolute
    # SQLite URIs take four slashes). A relative `sqlite:///mlflow.db` is
    # created in the container's working directory, outside every volume, so
    # the tracking database would be lost whenever the container is recreated.
    command: >
      mlflow server
      --host 0.0.0.0
      --port 5000
      --backend-store-uri sqlite:////mlflow_db/mlflow.db
      --default-artifact-root /mlflow_artifacts
    environment:
      MLFLOW_EXPERIMENTS_DEFAULT_ARTIFACT_LOCATION: /mlflow_artifacts
    volumes:
      - ./mlruns:/mlflow_artifacts
      - ./mlflow_db:/mlflow_db
    ports:
      - "${HOST_MLFLOW_PORT:-5000}:5000"
    restart: unless-stopped
    # `datascience` waits for this check (depends_on: service_healthy).
    healthcheck:
      test: ["CMD", "python", "-c", "import requests; requests.get('http://localhost:5000/health').raise_for_status()"]
      interval: 30s
      timeout: 10s
      retries: 3
      start_period: 30s

volumes:
  uv-cache:


Overwriting ../../docker-compose.yml


In [None]:
%%writefile ../../pyproject.toml
# Project metadata and dependency set resolved by `uv sync` in the Dockerfile.
[project]
name = "docker_dev_template"
version = "0.1.0"
description = "Hierarchical Bayesian modeling for baseball exit velocity data"
authors = [
  { name = "Marlins Data Science Team" },
]
license = "MIT"
readme = "README.md"

# ─── Restrict to Python 3.10–3.12 ──────────────────────────────
requires-python = ">=3.10,<3.13"

dependencies = [
  "pandas>=2.0",
  "numpy>=1.20,<2",
  "matplotlib>=3.4.0",
  "scikit-learn>=1.4.2",
  "pymc>=5.0.0",
  "arviz>=0.14.0",
  "statsmodels>=0.13.0",
  "jupyterlab>=3.0.0",
  "seaborn>=0.11.0",
  "tabulate>=0.9.0",
  "shap>=0.40.0",
  "xgboost>=1.5.0",
  "lightgbm>=3.3.0",
  "catboost>=1.0.0",
  "scipy>=1.7.0",
  "shapash[report]>=2.3.0",
  "shapiq>=1.3.0",
  "explainerdashboard>=0.3.0",
  "ipywidgets>=8.0.0",
  "nutpie>=0.7.1",
  "numpyro>=0.18.0,<1.0.0",
  "jax>=0.4.23",
  "jaxlib>=0.4.23",
  "pytensor>=2.18.3",
  "aesara>=2.9.4",
  "tqdm>=4.67.0",
  "pyarrow>=12.0.0",
  "streamlit>=1.20.0",
  "sqlalchemy>=1.4",
  "mysql-connector-python>=8.0",
  "optuna>=4.3.0",
  "bayesian-optimization>=1.2.0",
  "pretty_errors>=1.2.0",
  "gdown>=4.0.0",
  "invoke>=2.2",
  # ▶ Video download stack
  #   - pytube installed from the repository's default branch (direct Git
  #     reference, so it is not version-pinned)
  "pytube @ git+https://github.com/pytube/pytube",
  "yt-dlp>=2024.12.0",
  #   - optional convenience wrapper (does NOT install ffmpeg binary!)
  "ffmpeg-python >= 0.2.0",

  # ▶ Computer vision
  #   - Ultralytics YOLO (object detection, segmentation, etc.), pinned exactly
  "ultralytics==8.3.158",
  "opencv-python-headless>=4.10.0",
  "roboflow>=1.0.0",
  # ▶ Experiment tracking
  "mlflow>=3.1.1,<4.0.0",
  "optuna-integration[mlflow]>=4.4.0,<5.0.0",
  
  # ▶ PyTorch core libraries
  #   The requirements below carry no platform markers themselves; the
  #   per-platform wheel index is selected in [tool.uv.sources] further down.
  "torch>=2.0.0",
  "torchvision>=0.15.0",
  "torchaudio>=2.0.0",
  
  # ▶ Pydantic models and settings
  "pydantic>=2.0.0",
  "pydantic-settings>=2.0.0",
]

# Extras. NOTE(review): `dev` is declared here as an extra, not as a
# dependency group, so `uv sync --no-dev` presumably has nothing to exclude.
[project.optional-dependencies]
dev = [
  "pytest>=7.0.0",
  "black>=23.0.0",
  "isort>=5.0.0",
  "flake8>=5.0.0",
  "mypy>=1.0.0",
  "pre-commit>=3.0.0",
]

cuda = [
  "cupy-cuda12x>=12.0.0",  # For CUDA 12.x
]

# ─── uv configuration ──────────────────────────────────────────
[tool.uv]                   # uv reads this block
index-strategy = "unsafe-best-match"

# Named indexes for the PyTorch CUDA wheel variants. Each is marked
# `explicit = true`; only the index referenced from [tool.uv.sources] below
# (pytorch-cu128) is used by this file, the others are kept for switching.
[[tool.uv.index]]
name = "pytorch-cu121"
url = "https://download.pytorch.org/whl/cu121"
explicit = true

[[tool.uv.index]]
name = "pytorch-cu118"
url = "https://download.pytorch.org/whl/cu118"
explicit = true

[[tool.uv.index]]
name = "pytorch-cu124"
url = "https://download.pytorch.org/whl/cu124"
explicit = true

[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true

# The `torch-backend` option was removed from this table because the uv
# version in use at the time did not support it (it needs uv >= 0.5.3).
# The table is intentionally left empty.
[tool.uv.pip]
# (No keys set.)

# Map the PyTorch packages to a CUDA wheel index on Linux and Windows; other
# platforms fall back to the default index.
# Currently testing with CUDA 12.8.
# NOTE(review): the Dockerfile installs torch from the cu124 index and uses a
# CUDA 12.4 base image, while this table selects cu128 -- confirm which CUDA
# build is intended.
[tool.uv.sources]
torch = [
  { index = "pytorch-cu128", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
]
torchvision = [
  { index = "pytorch-cu128", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
]
torchaudio = [
  { index = "pytorch-cu128", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
]

# NOTE(review): PyTensor presumably takes its configuration from
# PYTENSOR_FLAGS / a .pytensorrc file rather than from pyproject.toml --
# confirm that this table has any effect.
[tool.pytensor]
device    = "cuda"
floatX    = "float32"
allow_gc  = true
optimizer = "fast_run"


Overwriting ../../pyproject.toml


In [7]:
%%writefile ../../.devcontainer/validate_gpu.py
#!/usr/bin/env python3
"""
GPU validation and environment diagnostics for RTX 4090 devcontainer.
Focus: verify JAX and PyTorch access to CUDA, report common misconfigurations.
"""
import sys
import os
import subprocess
import warnings
import textwrap
import re
warnings.filterwarnings('ignore')


def print_section(title: str) -> None:
    """Print ``title`` as a banner framed by two 60-character rules of '='.

    A blank line is emitted first so the banner stands apart from any
    preceding output; the title itself is indented by two spaces.
    """
    rule = "=" * 60
    print(f"\n{rule}\n  {title}\n{rule}")


def validate_environment_variables() -> bool:
    """Validate JAX-related environment variables (no inline comments, valid types).

    Numeric variables must parse as the declared type and must not contain a
    '#' (an inline comment that ended up in the value). A value outside its
    recommended range is reported but is not treated as a failure. String
    variables only produce a warning when they contain '#'.

    Returns:
        True when every numeric variable that is set parses cleanly; False if
        any of them contains '#' or is not a valid number.
    """
    print_section("JAX ENVIRONMENT VARIABLE VALIDATION")

    jax_numeric_vars = {
        'XLA_PYTHON_CLIENT_MEM_FRACTION': {'type': 'float', 'range': (0.0, 1.0)},
        'JAX_PREALLOCATION_SIZE_LIMIT_BYTES': {'type': 'int', 'range': (0, None)},
    }
    # A tuple (not a set) so the warnings are printed in a stable order.
    jax_string_vars = (
        'XLA_FLAGS', 'JAX_PLATFORM_NAME', 'XLA_PYTHON_CLIENT_ALLOCATOR', 'XLA_PYTHON_CLIENT_PREALLOCATE'
    )

    ok = True
    problems = []  # (variable name, cleaned value) pairs for the fix suggestions

    for var, cfg in jax_numeric_vars.items():
        value = os.environ.get(var)
        print(f"\nCheck {var} -> {value}")
        if value is None:
            print("  not set; defaults apply")
            continue
        if '#' in value:
            clean = value.split('#')[0].strip()
            print("  contains inline comment; use:", clean)
            problems.append((var, clean))
            ok = False
            continue

        parse = float if cfg['type'] == 'float' else int
        try:
            parsed = parse(value)
        except ValueError as e:
            print("  invalid numeric value:", e)
            ok = False
            continue

        # The range check applies to every numeric type, so e.g. a negative
        # byte limit is flagged just like an out-of-range fraction.
        low, high = cfg['range']
        below = low is not None and parsed < low
        above = high is not None and parsed > high
        print("  out of recommended range" if (below or above) else "  ok")

    for var in jax_string_vars:
        value = os.environ.get(var)
        if value and '#' in value:
            print(f"warn: {var} contains '#', which can break parsing")

    if problems:
        print("\nFix suggestions:")
        for var, clean in problems:
            print(f"export {var}={clean}")
    return ok


def check_environment() -> None:
    """Print interpreter, virtualenv, CUDA-variable and nvidia-smi diagnostics.

    Purely informational: nothing is returned, and a missing or failing
    ``nvidia-smi`` only produces a warning line.
    """
    print_section("ENVIRONMENT CHECK")
    env = os.environ
    print("python:", sys.executable)
    print("version:", sys.version)
    print("VIRTUAL_ENV:", env.get('VIRTUAL_ENV'))
    print("PATH contains .venv:", '.venv/bin' in env.get('PATH', ''))

    print("\nCUDA variables:")
    for name in ('CUDA_HOME', 'CUDA_PATH', 'CUDA_VISIBLE_DEVICES', 'LD_LIBRARY_PATH', 'NVIDIA_VISIBLE_DEVICES'):
        print(f"  {name}:", env.get(name, 'not set'))

    query = ['nvidia-smi', '--query-gpu=name,driver_version,memory.total', '--format=csv,noheader']
    try:
        proc = subprocess.run(query, capture_output=True, text=True)
    except FileNotFoundError:
        # The executable is not installed or not on PATH.
        print("\nwarn: nvidia-smi not found in path")
        return

    if proc.returncode != 0:
        print("\nwarn: nvidia-smi returned non‑zero")
        return
    print("\nGPU:", proc.stdout.strip())


def test_pytorch() -> bool:
    """Run a 2000x2000 matmul on the CUDA device through PyTorch.

    Returns:
        True when CUDA is available and the matmul completed; False when CUDA
        is unavailable or any exception (including a failed import) occurred.
    """
    print_section("PYTORCH GPU TEST")
    try:
        import torch
        print("version:", torch.__version__)
        print("cuda available:", torch.cuda.is_available())
        if not torch.cuda.is_available():
            return False

        import time
        print("device count:", torch.cuda.device_count())
        print("device 0:", torch.cuda.get_device_name(0))

        gpu = torch.device('cuda')
        lhs = torch.randn(2000, 2000, device=gpu)
        rhs = torch.randn(2000, 2000, device=gpu)

        # Untimed warm-up product, then wait for the device before timing.
        _ = lhs @ rhs
        torch.cuda.synchronize()

        started = time.time()
        product = lhs @ rhs
        torch.cuda.synchronize()
        print("matmul elapsed s:", round(time.time() - started, 3))

        # Pull a scalar back to the host so the result is actually consumed.
        _ = product.sum().item()
        return True
    except Exception as exc:
        print("pytorch test error:", exc)
        return False


def check_cudnn_compatibility() -> bool:
    """Check CuDNN version compatibility between PyTorch and JAX.

    Prints the CuDNN version PyTorch reports, the installed
    ``nvidia-cudnn-cu12`` package version, the CuDNN libraries found in the
    usual locations, and ``LD_LIBRARY_PATH``.

    Returns:
        True when PyTorch reports CuDNN 9 or newer; False when it reports an
        older version, reports no CuDNN at all, or the check raised.
    """
    print_section("CUDNN COMPATIBILITY CHECK")
    try:
        import glob
        import sysconfig
        from importlib import metadata

        import torch

        # Check PyTorch CuDNN version (None when CuDNN is not available)
        pytorch_cudnn = torch.backends.cudnn.version()
        print(f"PyTorch CuDNN version: {pytorch_cudnn}")

        # Check the nvidia-cudnn-cu12 package installed for this interpreter
        try:
            cudnn_pkg_version = metadata.version('nvidia-cudnn-cu12')
            print(f"Installed CuDNN package: nvidia-cudnn-cu12 {cudnn_pkg_version}")
        except metadata.PackageNotFoundError as e:
            print(f"Could not check CuDNN package version: {e}")

        # Check CuDNN library files. The site-packages directories come from
        # the running interpreter rather than a fixed "python3.10" path, so the
        # search stays correct when the image is built for another version.
        install_paths = sysconfig.get_paths()
        site_dirs = dict.fromkeys(install_paths[key] for key in ("purelib", "platlib"))
        cudnn_paths = [os.path.join(d, "nvidia", "cudnn", "lib") for d in site_dirs]
        cudnn_paths += ["/usr/local/cuda/lib64", "/usr/lib/x86_64-linux-gnu"]

        print("\nCuDNN library search:")
        for path in cudnn_paths:
            if not os.path.exists(path):
                print(f"  {path}: Path does not exist")
                continue
            cudnn_libs = sorted(glob.glob(f"{path}/libcudnn*"))
            if not cudnn_libs:
                print(f"  {path}: No CuDNN libraries found")
                continue
            print(f"  {path}: {len(cudnn_libs)} CuDNN libraries found")
            for lib in cudnn_libs[:3]:  # Show first 3
                print(f"    - {os.path.basename(lib)}")

        # Check LD_LIBRARY_PATH
        ld_path = os.environ.get('LD_LIBRARY_PATH', '')
        print(f"\nLD_LIBRARY_PATH: {ld_path}")

        # Version compatibility check
        if pytorch_cudnn is None:
            print("WARNING: PyTorch reports no CuDNN version (CuDNN unavailable)")
            return False
        # CuDNN 8.x is encoded as major*1000 + minor*100 + patch (8.9.2 -> 8902)
        # and 9.x as major*10000 + minor*100 + patch (9.1.0 -> 90100), so every
        # 8.x value is below 9000 and every 9.x value is above it.
        if pytorch_cudnn < 9000:
            print("WARNING: PyTorch CuDNN version may be too old for JAX")
            return False

        print("CuDNN compatibility check passed")
        return True

    except Exception as e:
        print(f"CuDNN compatibility check failed: {e}")
        return False


def test_jax_initialization() -> bool:
    """Import JAX and build the PJRT GPU plugin options.

    Returns:
        True when both the imports and the option generation succeed; False
        otherwise, after printing the error and a hint for known causes.
    """
    print_section("JAX INITIALIZATION TEST")

    # Stage 1: import jax / jaxlib with unfiltered tracebacks.
    try:
        os.environ['JAX_TRACEBACK_FILTERING'] = 'off'
        import jax
        import jaxlib
        from jaxlib import xla_client
        print("jax:", jax.__version__, "jaxlib:", jaxlib.__version__)
    except Exception as exc:
        print("jax init error:", exc)
        if "CuDNN" in str(exc):
            print("hint: CuDNN-related error detected")
        return False

    # Stage 2: option generation reads the XLA_PYTHON_CLIENT_* variables, so a
    # malformed value or a CuDNN mismatch shows up here.
    try:
        plugin_opts = xla_client.generate_pjrt_gpu_plugin_options()
        print("gpu plugin options ok; memory_fraction:", plugin_opts.get('memory_fraction', 'not set'))
    except Exception as exc:
        message = str(exc)
        print("gpu plugin options error:", exc)
        if "could not convert string to float" in message:
            print("hint: check XLA_PYTHON_CLIENT_MEM_FRACTION for inline comments")
        elif "CuDNN" in message and "version" in message:
            print("hint: CuDNN version mismatch detected - check compatibility")
        return False

    return True


def test_jax() -> bool:
    """Run a 2000x2000 matmul on the first GPU device visible to JAX.

    Returns:
        True when a GPU device was found and the computation completed; False
        when JAX exposes no GPU device or any exception occurred.
    """
    print_section("JAX GPU TEST")
    try:
        os.environ['JAX_TRACEBACK_FILTERING'] = 'off'
        import time
        import jax
        import jax.numpy as jnp

        # jax.default_backend() is the public way to read the active platform;
        # jax.lib.xla_bridge is deprecated and is not importable on newer JAX
        # releases, which made this test fail even with a working GPU.
        print("backend:", jax.default_backend())
        devices = jax.devices()
        print("devices:", devices)

        # Match on the platform attribute first; the string form is a fallback
        # because it renders as e.g. "cuda:0" or "gpu:0" depending on version.
        gpus = [
            d for d in devices
            if getattr(d, 'platform', '').lower() in {'gpu', 'cuda'}
            or 'gpu' in str(d).lower()
            or 'cuda' in str(d).lower()
        ]
        if not gpus:
            print("no gpu devices detected by jax")
            return False

        # quick compute
        key = jax.random.PRNGKey(0)
        x = jax.random.normal(key, (2000, 2000))
        x = jax.device_put(x, gpus[0])
        t0 = time.time()
        s = jnp.sum(x @ x).block_until_ready()
        print("matmul elapsed s:", round(time.time() - t0, 3), "sum:", float(s))
        return True
    except Exception as e:
        print("jax test error:", e)
        return False


def main() -> int:
    """Run the GPU diagnostics and return a process exit code.

    Flags:
        --quick  only validate the environment variables and run the PyTorch
                 test.
        --fix    after a full run, print recommendations for each failed check.

    Returns:
        0 when every executed check passed, 1 otherwise.
    """
    import argparse
    p = argparse.ArgumentParser()
    p.add_argument('--quick', action='store_true')
    p.add_argument('--fix', action='store_true', help='Run with fix recommendations')
    args = p.parse_args()

    if args.quick:
        env_ok = validate_environment_variables()
        pt_ok = test_pytorch()
        return 0 if (env_ok and pt_ok) else 1

    env_ok = validate_environment_variables()
    check_environment()
    cudnn_ok = check_cudnn_compatibility()
    jax_init_ok = test_jax_initialization()
    jax_ok = test_jax()
    pt_ok = test_pytorch()

    print_section("SUMMARY")
    print("env vars:", "ok" if env_ok else "fail")
    print("cudnn compatibility:", "ok" if cudnn_ok else "fail")
    print("jax init:", "ok" if jax_init_ok else "fail")
    print("jax compute:", "ok" if jax_ok else "fail")
    print("pytorch:", "ok" if pt_ok else "fail")

    all_ok = env_ok and cudnn_ok and jax_init_ok and jax_ok and pt_ok

    # Provide fix recommendations if requested. Every check that can fail has
    # an entry, so the section is never printed empty.
    if args.fix and not all_ok:
        print_section("FIX RECOMMENDATIONS")
        if not cudnn_ok:
            print("1. CuDNN version mismatch detected:")
            print("   - Upgrade nvidia-cudnn-cu12 to version >= 9.8.0")
            print("   - Ensure LD_LIBRARY_PATH includes CuDNN library paths")
        if not jax_init_ok:
            print("2. JAX initialization failed:")
            print("   - Check CuDNN compatibility")
            print("   - Verify XLA environment variables (no inline comments)")
        if not jax_ok:
            print("3. JAX GPU computation failed:")
            print("   - Verify GPU is accessible")
            print("   - Check CUDA driver compatibility")
        if not env_ok:
            print("4. Environment variable validation failed:")
            print("   - Remove inline comments from the JAX/XLA variables")
            print("   - Apply the 'export ...' suggestions printed by the validation step")
        if not pt_ok:
            print("5. PyTorch GPU test failed:")
            print("   - Verify the container was started with GPU access")
            print("   - Check that the installed torch build has CUDA support")

    return 0 if all_ok else 1


if __name__ == '__main__':
    sys.exit(main())


Overwriting ../../.devcontainer/validate_gpu.py


In [8]:
%%writefile ../../.devcontainer/tests/test_pytorch_gpu.py
#!/usr/bin/env python3
"""Small PyTorch GPU benchmark."""
import time


def test_pytorch(force_cpu: bool = False) -> None:
    """Time one 1000x1000 matmul on the GPU, or on the CPU as a fallback.

    Args:
        force_cpu: run on the CPU even when CUDA is available.

    Prints the selected device (with its compute capability, e.g. ``sm_89``)
    and the elapsed time of a single matmul after one untimed warm-up.
    """
    import torch
    cuda_ok = torch.cuda.is_available() and not force_cpu
    if cuda_ok:
        name = torch.cuda.get_device_name(0)
        major, minor = torch.cuda.get_device_capability(0)
        # Compute capability is written as the two numbers concatenated
        # (8.9 -> sm_89); zero-padding the minor part would print sm_809.
        print(f"device: {name} (sm_{major}{minor})")
        device = torch.device("cuda:0")
    else:
        print("falling back to cpu")
        device = torch.device("cpu")

    size = (1000, 1000)
    a, b = (torch.randn(size, device=device) for _ in range(2))

    # Warm-up run; CUDA kernels launch asynchronously, so wait for it to finish
    # before starting the clock or its cost leaks into the measurement.
    _ = a @ b
    if device.type == "cuda":
        torch.cuda.synchronize()

    t0 = time.perf_counter()
    _ = (a @ b).sum().item()
    if device.type == "cuda":
        torch.cuda.synchronize()
    print(f"matmul on {device} took {(time.perf_counter()-t0)*1000:.2f} ms")


if __name__ == "__main__":
    test_pytorch()


Overwriting ../../.devcontainer/tests/test_pytorch_gpu.py


In [9]:
%%writefile ../../.devcontainer/tests/test_uv.py
"""UV and key package presence check.

Prints the uv version, the running interpreter, and the version of each key
package (or the error raised while importing it).
"""
import subprocess
import sys

KEY_PACKAGES = ("numpy", "pandas", "matplotlib", "scipy", "sklearn", "jupyterlab", "seaborn", "tqdm")

print("UV version:")
try:
    uv_proc = subprocess.run(["uv", "--version"], capture_output=True, text=True)
except FileNotFoundError:
    print("uv not found")
else:
    # Fall back to stderr when nothing was written to stdout.
    print(uv_proc.stdout.strip() or uv_proc.stderr.strip())

print("\nPython:")
print(sys.executable)
print(sys.version)

print("\nKey packages:")
for name in KEY_PACKAGES:
    try:
        # __import__ returns the top-level module, which is all that is needed
        # for these single-segment names (including sklearn).
        module = __import__(name)
        print(name, getattr(module, "__version__", "unknown"))
    except Exception as exc:
        print(name, "missing or error:", exc)


Overwriting ../../.devcontainer/tests/test_uv.py


In [10]:
%%writefile ../../.devcontainer/tests/test_pytorch.py
# Quick PyTorch smoke test: report the version, CUDA availability, every
# visible device, and the sum of a small tensor created on the first GPU.
# Any failure (including a missing torch) is printed instead of raised.
print("PyTorch quick check")
try:
    import torch

    print("version:", torch.__version__)
    has_cuda = torch.cuda.is_available()
    print("cuda:", has_cuda)
    if has_cuda:
        device_count = torch.cuda.device_count()
        print("devices:", device_count)
        for index in range(device_count):
            print(index, torch.cuda.get_device_name(index))
        ones = torch.ones(100, 100, device='cuda:0')
        print("sum:", float(torch.sum(ones)))
except Exception as exc:
    print("error:", exc)


Overwriting ../../.devcontainer/tests/test_pytorch.py


In [11]:
%%writefile ../../.devcontainer/tests/test_uv.py
# Test other critical packages
print("\n📦 Testing other critical packages...")

packages_to_test = [
    'numpy', 'pandas', 'matplotlib', 'scipy', 'sklearn', 
    'jupyterlab', 'seaborn', 'tqdm'
]

for package in packages_to_test:
    try:
        if package == 'sklearn':
            import sklearn
            version = sklearn.__version__
        else:
            module = __import__(package)
            version = getattr(module, '__version__', 'unknown')
        print(f"   ✅ {package}: {version}")
    except ImportError:
        print(f"   ❌ {package}: Not installed")
    except Exception as e:
        print(f"   ⚠️  {package}: Error - {e}")


Overwriting ../../.devcontainer/tests/test_uv.py


In [12]:
%%writefile ../../.devcontainer/tests/test_summary.py
#!/usr/bin/env python3
"""Aggregated checks for the devcontainer layout and GPU readiness."""
import os
import sys
import time
import subprocess


def section(t):
    """Print heading ``t`` between two 60-character rules, after a blank line."""
    rule = "=" * 60
    print("\n" + rule, t, rule, sep="\n")


def test_structure() -> bool:
    """Report which of the expected project and image paths exist.

    Returns:
        True only when every expected path is present.
    """
    section("STRUCTURE")
    required_paths = (
        '/workspace/docker-compose.yml',
        '/workspace/pyproject.toml',
        '/workspace/.devcontainer/devcontainer.json',
        '/workspace/.devcontainer/Dockerfile',
        '/workspace/.devcontainer/.env.template',
        '/workspace/.devcontainer/.dockerignore',
        '/app/validate_gpu.py',
        '/app/tests/'
    )
    all_present = True
    for path in required_paths:
        present = os.path.exists(path)
        print("ok:" if present else "missing:", path)
        # Keep going after a miss so every path gets reported.
        all_present = all_present and present
    return all_present


def test_uv() -> bool:
    """Print the uv version and report whether ``uv --version`` succeeded.

    Returns:
        True when the command exits with status 0; False when it fails or the
        executable cannot be found.
    """
    section("UV")
    try:
        proc = subprocess.run(['uv', '--version'], capture_output=True, text=True)
    except FileNotFoundError:
        print('uv not in PATH')
        return False
    # Fall back to stderr when nothing was written to stdout.
    print(proc.stdout.strip() or proc.stderr.strip())
    return proc.returncode == 0


def test_pytorch() -> bool:
    """Sum a 512x512 tensor of ones on the first CUDA device.

    Returns:
        True when CUDA is available and the sum was computed; False when CUDA
        is unavailable or any exception (including a failed import) occurred.
    """
    section("PYTORCH")
    try:
        import torch
        print("version:", torch.__version__)
        print("cuda:", torch.cuda.is_available())
        if not torch.cuda.is_available():
            return False

        gpu = torch.device('cuda:0')
        total = torch.sum(torch.ones(512, 512, device=gpu))
        print("sum:", total.item())
        return True
    except Exception as exc:
        print("error:", exc)
        return False


def test_jax() -> bool:
    """Sum a 512x512 array of ones on the first GPU device visible to JAX.

    Returns:
        True when a GPU device was found and the computation completed; False
        when no GPU device is visible or any exception occurred.
    """
    section("JAX")
    try:
        import jax, jax.numpy as jnp

        # Show all devices for visibility
        devs = jax.devices()
        print("devices:", devs)

        # Prefer the supported filtered query. It raises RuntimeError when no
        # GPU backend is present, so catch that to let the fallback scan run
        # and to report the plain "no gpu devices" message.
        try:
            gpus = jax.devices("gpu")
        except RuntimeError:
            gpus = []

        # Fallback for older/newer renderings (e.g., "CudaDevice(id=0)")
        if not gpus:
            gpus = [
                d for d in devs
                if getattr(d, "platform", "").lower() in {"gpu", "cuda"} or "cuda" in str(d).lower()
            ]

        if not gpus:
            print("no gpu devices detected by jax")
            return False

        # Tiny compute on the first GPU to ensure execution
        x = jnp.ones((512, 512), dtype=jnp.float32)
        x = jax.device_put(x, gpus[0])
        s = jnp.sum(x).block_until_ready()
        print("sum:", float(s))
        return True
    except Exception as e:
        print("error:", e)
        return False



def main() -> int:
    """Run every check, print a one-line summary, and return an exit code.

    Returns:
        0 when the structure, uv, PyTorch and JAX checks all passed, else 1.
    """
    # All four checks always run, in this order, regardless of earlier results.
    structure_ok = test_structure()
    uv_ok = test_uv()
    pytorch_ok = test_pytorch()
    jax_ok = test_jax()

    section("SUMMARY")
    print("structure:", structure_ok, "uv:", uv_ok, "pytorch:", pytorch_ok, "jax:", jax_ok)
    if structure_ok and uv_ok and pytorch_ok and jax_ok:
        return 0
    return 1


if __name__ == '__main__':
    sys.exit(main())

Overwriting ../../.devcontainer/tests/test_summary.py
