### Before starting: make sure no other containers for this dev container are running; if any are, stop and remove them


In [1]:
%%writefile .devcontainer/.dockerignore
**/.git
**/.vscode
**/.idea
**/__pycache__
**/*.pyc
**/*.pyo
**/*.pyd
**/*.swp
**/venv
**/env
.env
*.code-workspace
data/
notebooks/**/*.ipynb_checkpoints
*.log
.DS_Store
Thumbs.db

# --- NEW: keep build context tiny & readable ---
# NOTE: .dockerignore only recognises comments that start a line; a trailing
# "# ..." becomes part of the pattern and the rule silently stops matching.
# MLflow runs & artifacts (often huge, root-owned)
mlruns/
# ensure nested paths are ignored
mlruns/**



Overwriting .devcontainer/.dockerignore


In [2]:
%%writefile .dockerignore
**/.git
**/.vscode
**/.idea
**/__pycache__
**/*.pyc
**/*.pyo
**/*.pyd
**/*.swp
**/venv
**/env
.env
*.code-workspace
data/
notebooks/**/*.ipynb_checkpoints
*.log
.DS_Store
Thumbs.db

# --- NEW: keep build context tiny & readable ---
# NOTE: .dockerignore only recognises comments that start a line; a trailing
# "# ..." becomes part of the pattern and the rule silently stops matching.
# MLflow runs & artifacts (often huge, root-owned)
mlruns/
# ensure nested paths are ignored
mlruns/**


Overwriting .dockerignore


In [3]:
%%writefile .env.template 
# Compose project name and the ENV_NAME build arg (venv prompt) in docker-compose.yml.
ENV_NAME=docker_dev_template
# NOTE(review): docker-compose.yml does not forward CUDA_TAG under build.args,
# so the Dockerfile's own ARG default appears to decide the base image — confirm.
CUDA_TAG=12.8.0
DOCKER_BUILDKIT=1
# Host-side ports published by docker-compose.yml; the container-side ports are fixed there.
HOST_JUPYTER_PORT=8890
HOST_TENSORBOARD_PORT=6008
HOST_EXPLAINER_PORT=8050
HOST_STREAMLIT_PORT=8501
HOST_MLFLOW_PORT=5000
# Forwarded as the PYTHON_VER build arg and into the container environment.
PYTHON_VER=3.10
# JAX / XLA / TensorFlow runtime settings passed to the datascience container.
# The XLA_PYTHON_CLIENT_* values and JAX_PREALLOCATION_SIZE_LIMIT_BYTES also feed
# the JAX_* build args in docker-compose.yml.
JAX_PLATFORM_NAME=gpu
XLA_PYTHON_CLIENT_PREALLOCATE=true
XLA_PYTHON_CLIENT_ALLOCATOR=platform
XLA_PYTHON_CLIENT_MEM_FRACTION=0.95
XLA_FLAGS=--xla_force_host_platform_device_count=1
JAX_DISABLE_JIT=false
JAX_ENABLE_X64=false
TF_FORCE_GPU_ALLOW_GROWTH=false
# 8589934592 bytes = 8 GiB
JAX_PREALLOCATION_SIZE_LIMIT_BYTES=8589934592
# Snowflake config must go in a separate file or devcontainer.env



Overwriting .env.template


In [4]:
%%writefile .cursor/settings.json
{
    "http.proxySupport": "off",
    "update.mode": "manual",
    "extensions.autoUpdate": false,
    "extensions.autoCheckUpdates": false,
    "python.defaultInterpreterPath": "/app/.venv/bin/python",
    "jupyter.interactiveWindow.textEditor.executeSelection": true,
    "jupyter.widgetScriptSources": ["jsdelivr.com", "unpkg.com"],
    "jupyter.experiments.enabled": false,
    "jupyter.telemetry.enabled": false,
    "python.telemetry.enabled": false,
    "telemetry.telemetryLevel": "off"
} 

Overwriting .cursor/settings.json


In [5]:
%%writefile .devcontainer/devcontainer.json
{
  "name": "docker_dev_template_uv",
  "dockerComposeFile": ["../docker-compose.yml"],
  "service": "datascience",
  "workspaceFolder": "/workspace",
  "shutdownAction": "stopCompose",
  "runArgs": ["--gpus", "all"],
  "customizations": {
    "vscode": {
      "extensions": [
        "ms-python.python",
        "ms-python.vscode-pylance",
        "ms-toolsai.jupyter",
        "ms-toolsai.jupyter-renderers"
      ],
      "settings": {
        // 1. COMPREHENSIVE TELEMETRY SETTINGS
        "telemetry.telemetryLevel": "off",
        "python.telemetry.enabled": false,
        "jupyter.telemetry.enabled": false,
        "jupyter.experiments.enabled": false,
        "update.mode": "manual",
        "extensions.autoUpdate": false,
        "extensions.autoCheckUpdates": false,

        // 2. MOVE HEAVY EXTENSIONS TO LOCAL UI HOST
        "remote.extensionKind": {
          "ms-python.python": ["ui"],
          "ms-python.vscode-pylance": ["ui"],
          "ms-toolsai.jupyter": ["ui"],
          "ms-toolsai.jupyter-renderers": ["ui"]
        },

        // 3. PYTHON AND JUPYTER SETTINGS
        // Use the venv baked into the image at /app/.venv (the Dockerfile's VIRTUAL_ENV).
        // docker-compose.yml bind-mounts the repo over /workspace, which hides the
        // /workspace/.venv symlink created at build time.
        "python.defaultInterpreterPath": "/app/.venv/bin/python",
        "jupyter.interactiveWindow.textEditor.executeSelection": true,
        "jupyter.widgetScriptSources": ["jsdelivr.com", "unpkg.com"]
      }
    }
  },
  "remoteEnv": {
    "JUPYTER_ENABLE_LAB": "true"
  },

  // After container creation, set up env, check UV, Python, and key libs
  "postCreateCommand": [
    "/bin/sh",
    "-c",
    ".devcontainer/setup_env.sh && \\\necho '## uv diagnostics ##' && uv --version && \\\necho '## python ##' && which python && python -V && \\\nexec .devcontainer/verify_env.py"
  ]
}


Overwriting .devcontainer/devcontainer.json


In [6]:
%%writefile .devcontainer/Dockerfile
# .devcontainer/Dockerfile — uv‑based replacement for the previous Conda image
# -----------------------------------------------------------------------------
# CUDA + cuDNN base with drivers already installed --------------------------------
# Single source of truth for the CUDA base-image tag.
# (Dockerfile comments must be on their own line: a trailing "# ..." is parsed
# as additional arguments of the instruction, not as a comment.)
ARG CUDA_TAG=12.8.0
FROM nvidia/cuda:${CUDA_TAG}-cudnn-devel-ubuntu22.04

# ---------- build-time ARGs ---------------------------------------------------
ARG PYTHON_VER=3.10
ARG ENV_NAME=docker_dev_template
ARG JAX_PREALLOCATE=true
ARG JAX_MEM_FRAC=0.95
ARG JAX_ALLOCATOR=platform
ARG JAX_PREALLOC_LIMIT=8589934592
ENV DEBIAN_FRONTEND=noninteractive

# ----------------------------------------------------------------------------
# 1) Core OS deps, build tools, & Python (system) -----------------------------
RUN --mount=type=cache,target=/var/cache/apt \
    --mount=type=cache,target=/var/lib/apt \
    apt-get update && apt-get install -y --no-install-recommends \
        bash curl ca-certificates git procps htop util-linux build-essential \
        python3 python3-venv python3-pip python3-dev \
        autoconf automake libtool m4 cmake pkg-config \
        jags iproute2 net-tools lsof \
        && pkg-config --modversion jags \
        && apt-get clean && rm -rf /var/lib/apt/lists/*

# Install Node.js for VS Code remote extension host
RUN curl -fsSL https://deb.nodesource.com/setup_18.x | bash - && \
    apt-get update && apt-get install -y nodejs && \
    rm -rf /var/lib/apt/lists/*

# ----------------------------------------------------------------------------
# 2) Copy a *pinned* uv & uvx binary pair from the official distroless image --
COPY --from=ghcr.io/astral-sh/uv:0.7.12 /uv /uvx /bin/

# ----------------------------------------------------------------------------
# 3) Create project dir & copy only the lock/manifest for best layer‑caching --
WORKDIR /app
COPY pyproject.toml uv.lock* ./

# ----------------------------------------------------------------------------
# 4) Create an in-project venv, install deps, then symlink into /workspace
RUN --mount=type=cache,target=/root/.cache/uv \
    mkdir -p /workspace && \
    uv venv .venv --python "${PYTHON_VER}" --prompt "${ENV_NAME}" && \
    (uv sync --locked || (echo "⚠️  Lock drift detected – regenerating" \
        && uv lock --upgrade --quiet && uv sync)) && \
    ln -s /app/.venv /workspace/.venv

# Promote venv for all later layers ------------------------------------------------
ENV VIRTUAL_ENV=/app/.venv
ENV PATH="/app/.venv/bin:${PATH}"

# ----------------------------------------------------------------------------
# 5) ---------- CUDA wheels -------------------------------------------------------
RUN --mount=type=cache,target=/root/.cache/uv \
    uv pip install --pre --no-cache-dir \
        torch torchvision torchaudio \
        --index-url https://download.pytorch.org/whl/nightly/cu128 && \
    uv pip install --no-cache-dir \
        "jax[cuda12]==0.6.0" \
        -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# --- CUDA toolkit sanity check (robust for runtime *and* devel images) ------
RUN set -e; \
    # 1️⃣ First try: any cuda-<ver> folder?
    CUDA_REAL="$(ls -d /usr/local/cuda-* 2>/dev/null | sort -V | tail -n1 || true)"; \
    # 2️⃣ Fallback: flat layout shipped by some runtime images
    if [ -z "$CUDA_REAL" ] && [ -d /usr/local/cuda ]; then \
        CUDA_REAL="/usr/local/cuda"; \
    fi; \
    # 3️⃣ Bail if still empty
    if [ -z "$CUDA_REAL" ]; then \
        echo '❌  No CUDA toolkit folder found — aborting.' >&2; exit 1; \
    fi; \
    # 4️⃣ Refresh the canonical symlink only when needed
    if [ "$CUDA_REAL" != "/usr/local/cuda" ]; then \
        echo "🔧  Linking /usr/local/cuda -> $CUDA_REAL"; \
        ln -sfn "$CUDA_REAL" /usr/local/cuda; \
    fi; \
    echo "🟢  CUDA toolkit detected at $CUDA_REAL"

# ----------------------------------------------------------------------------
# 6) Install PyJAGS with the cstdint header work‑around -----------------------
RUN CPPFLAGS="-include cstdint" uv pip install --no-build-isolation pyjags==1.3.8

# ----------------------------------------------------------------------------
# 7) Copy *rest* of the project after deps → fast rebuild when code changes ---
COPY . /app

# ----------------------------------------------------------------------------
# 8) GPU‑tuning env vars (carried forward from Conda‑based image) -------------
ENV XLA_PYTHON_CLIENT_PREALLOCATE=${JAX_PREALLOCATE}
ENV XLA_PYTHON_CLIENT_MEM_FRACTION=${JAX_MEM_FRAC}
ENV XLA_PYTHON_CLIENT_ALLOCATOR=${JAX_ALLOCATOR}
ENV JAX_PLATFORM_NAME=gpu
ENV XLA_FLAGS="--xla_force_host_platform_device_count=1"
ENV JAX_DISABLE_JIT=false
ENV JAX_ENABLE_X64=false
ENV TF_FORCE_GPU_ALLOW_GROWTH=false
ENV JAX_PREALLOCATION_SIZE_LIMIT_BYTES=${JAX_PREALLOC_LIMIT}

# ----------------------------------------------------------------------------
# 9) Library path so PyJAGS & CUDA libs resolve correctly ---------------------
ENV LD_LIBRARY_PATH="/app/.venv/lib:${LD_LIBRARY_PATH}"

# ----------------------------------------------------------------------------
# 10) Final working directory & default command ------------------------------
WORKDIR /workspace
CMD ["bash"]

# 11) Force login shells & VS Code terminals to land in /workspace
RUN echo 'cd /workspace' > /etc/profile.d/99-workspace-cd.sh

# 12) Force every IPython / Jupyter kernel to start in /workspace
RUN mkdir -p /root/.ipython/profile_default/startup && \
    printf "import os, sys\nos.chdir('/workspace')\nsys.path.append('/workspace')\n" \
      > /root/.ipython/profile_default/startup/00-cd-workspace.py

# 13) Auto-activate uv venv in every login shell
RUN echo '. /app/.venv/bin/activate' > /etc/profile.d/10-uv-activate.sh





Overwriting .devcontainer/Dockerfile


In [7]:
%%writefile .devcontainer/verify_env.py
#!/usr/bin/env python3
"""Post-create diagnostics for the dev container.

Importing the libraries below is itself the check: a missing or broken
install raises ImportError and the script exits non-zero. On success it
prints the interpreter path/version and the torch / jax GPU status.

This cell previously wrote a shell heredoc wrapper (``cat << 'EOF' ...``)
into the file, so the resulting ``verify_env.py`` was not valid Python;
only the Python payload is kept here.
"""
import encodings
import sys

import jax
import jupyterlab
import torch

print("## Python & library diagnostics ##")
print("Python:", sys.executable, sys.version.split()[0])
print("🟢 encodings OK")
print("🟢 jupyterlab OK")
print("🟢 torch", torch.__version__, "CUDA:", torch.cuda.is_available())
print("🟢 jax", jax.__version__, "devices:", jax.devices())


Overwriting .devcontainer/verify_env.py


In [8]:
%%writefile .devcontainer/setup_env.sh
#!/usr/bin/env sh
# Seed /workspace/.env from the checked-in template, but only when no .env
# exists yet, so local secrets in an existing file are never overwritten.
set -eu

env_file=/workspace/.env
template=/workspace/.env.template

[ -f "$env_file" ] || {
  echo "📝  Generating default .env from template"
  cp "$template" "$env_file"
}

Overwriting .devcontainer/setup_env.sh


In [9]:
%%writefile .devcontainer/gpu_verify.py
#!/usr/bin/env python3
"""
Verify that the GPU is accessible and JAX is correctly configured.
This script is used during container startup.
"""

import sys

def check_gpu():
    """Report whether JAX can see a GPU and run a small computation on it.

    Prints the JAX version and every visible device, then sums a
    1000x1000 array of ones as a smoke test.

    Returns:
        bool: True when the first JAX device is a GPU and the test
        computation ran; False when JAX is missing, no GPU is visible,
        or any step raised.
    """
    print("Checking GPU availability...")
    try:
        import jax
    except ImportError:
        # Only a genuinely missing JAX should produce this message.
        print("JAX not found! Make sure JAX is installed with GPU support.")
        return False

    try:
        jax.config.update('jax_platform_name', 'gpu')

        # Get device count and details
        devices = jax.devices()
        device_count = len(devices)
        print(f"JAX version: {jax.__version__}")
        print(f"Available devices: {device_count}")

        for i, device in enumerate(devices):
            print(f"Device {i}: {device}")

        # Decide from the device's `platform` attribute rather than its string
        # form: a CUDA device can print as e.g. "cuda:0", which does not
        # contain the substring "gpu".
        gpu_platforms = {"gpu", "cuda", "rocm"}
        first_platform = (
            str(getattr(devices[0], "platform", "")).lower() if devices else ""
        )
        if device_count == 0 or first_platform not in gpu_platforms:
            print("WARNING: No GPU devices found by JAX!")
            return False

        # The former `jax.tools.jax_jit.get_jax_jit_flags()` probe was removed:
        # importing that module raised ImportError, which was then misreported
        # as "JAX not found!" and made the check fail on every run.

        # Run a simple GPU computation
        print("Running a test computation on GPU...")
        x = jax.numpy.ones((1000, 1000))
        result = jax.numpy.sum(x, axis=0)
        print(f"Test computation result shape: {result.shape}")

        print("JAX GPU verification completed successfully!")
        return True

    except Exception as e:
        print(f"Error during GPU verification: {e}")
        return False


if __name__ == "__main__":
    if not check_gpu():
        print("WARNING: GPU verification failed!")
    # Deliberately exit 0 in both cases so a failed GPU check does not stop
    # the container from starting.
    sys.exit(0)

Overwriting .devcontainer/gpu_verify.py


In [10]:
%%writefile .devcontainer/jags_verify.py
#!/usr/bin/env python3
"""
Smoke test for the PyJAGS installation.

Builds a one-observation normal model, draws a short chain and prints the
posterior means. Exits with status 0 on success and 1 when PyJAGS cannot be
imported or any step raises.
"""

import sys

try:
    import pyjags
    print(f"PyJAGS version: {pyjags.__version__}")

    # Minimal model: a single normal observation with vague priors.
    model_code = """
    model {
        # Likelihood
        y ~ dnorm(mu, 1/sigma^2)
        
        # Priors
        mu ~ dnorm(0, 0.001)
        sigma ~ dunif(0, 100)
    }
    """

    observed = {'y': 0.5}

    # Compile the model against the single data point.
    jags_model = pyjags.Model(model_code, data=observed, chains=1, adapt=100)
    print("JAGS model initialized successfully!")

    # Draw from the posterior.
    draws = jags_model.sample(200, vars=['mu', 'sigma'])
    print("JAGS sampling completed successfully!")

    # Both variables must be present before anything is reported.
    mu_draws, sigma_draws = draws['mu'], draws['sigma']
    print(f"mu mean: {mu_draws.mean():.4f}")
    print(f"sigma mean: {sigma_draws.mean():.4f}")

    print("PyJAGS verification completed successfully!")
except ImportError:
    print("PyJAGS not found!")
    sys.exit(1)
except Exception as err:
    print(f"Error during PyJAGS verification: {err}")
    sys.exit(1)
else:
    sys.exit(0)

Overwriting .devcontainer/jags_verify.py


In [11]:
%%writefile .devcontainer/pyjags_patch.py
#!/usr/bin/env python3
import os
import sys

def patch_pyjags_sources():
    """Download the PyJAGS 1.3.8 sdist, patch in ``<cstdint>`` and install it.

    Steps: ``pip download`` the source archive, unpack it with ``tar``,
    prepend ``#include <cstdint>`` to every ``.cpp``/``.h`` file under
    ``src`` that lacks it, then ``pip install --no-build-isolation .``.

    Side effects: changes the process working directory to
    ``pyjags-1.3.8`` and leaves the downloaded/unpacked files on disk.

    Returns:
        int: 0 on success, non-zero when the archive could not be obtained,
        unpacked or installed. (Previously every external command's exit
        status was ignored and the function always reported success.)
    """
    import subprocess

    archive = "pyjags-1.3.8.tar.gz"
    source_dir = "pyjags-1.3.8"

    print("Downloading and patching PyJAGS source...")
    try:
        download = subprocess.run(
            ["pip", "download", "--no-binary", ":all:", "pyjags==1.3.8"]
        )
    except OSError as exc:
        print(f"Could not run pip: {exc}")
        return 1
    if download.returncode != 0:
        # pip may fail on unrelated dependency sdists; what matters is
        # whether the PyJAGS archive itself arrived.
        print(f"WARNING: pip download exited with status {download.returncode}")
    if not os.path.exists(archive):
        print(f"PyJAGS source archive {archive} was not downloaded.")
        return download.returncode or 1

    try:
        unpack = subprocess.run(["tar", "-xzf", archive])
    except OSError as exc:
        print(f"Could not run tar: {exc}")
        return 1
    if unpack.returncode != 0:
        print(f"Failed to unpack {archive} (tar exited with {unpack.returncode}).")
        return unpack.returncode

    os.chdir(source_dir)

    # Add cstdint include to all cpp files
    for root, dirs, files in os.walk("src"):
        for file in files:
            if file.endswith(".cpp") or file.endswith(".h"):
                filepath = os.path.join(root, file)
                with open(filepath, 'r') as f:
                    content = f.read()
                if "#include <cstdint>" not in content:
                    with open(filepath, 'w') as f:
                        f.write("#include <cstdint>\n" + content)
                    print(f"Patched {filepath}")

    # Build and install
    try:
        install = subprocess.run(["pip", "install", "--no-build-isolation", "."])
    except OSError as exc:
        print(f"Could not run pip: {exc}")
        return 1
    if install.returncode != 0:
        print(f"PyJAGS installation failed (pip exited with {install.returncode}).")
        return install.returncode

    print("PyJAGS installation complete!")
    return 0

if __name__ == "__main__":
    sys.exit(patch_pyjags_sources())

Overwriting .devcontainer/pyjags_patch.py


In [12]:
%%writefile .pre-commit-config.yaml

repos:
  - repo: https://github.com/astral-sh/uv-pre-commit
    rev: 0.5.7  # Use the ref you want to point at
    hooks:
      - id: uv-lock      # keep uv.lock in sync
      - id: uv-export    
        args: [--extra=dev, --output-file=requirements-dev.txt]
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.6.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-added-large-files 

Overwriting .pre-commit-config.yaml


In [13]:
%%writefile docker-compose.yml
# docker-compose.yml
name: ${ENV_NAME:-docker_dev_template}

services:
  datascience:
    build:
      context: .
      dockerfile: .devcontainer/Dockerfile
      args:
        PYTHON_VER: ${PYTHON_VER:-3.10}
        ENV_NAME: ${ENV_NAME:-docker_dev_template}
        JAX_PREALLOCATE: ${XLA_PYTHON_CLIENT_PREALLOCATE:-true}
        JAX_MEM_FRAC: ${XLA_PYTHON_CLIENT_MEM_FRACTION:-0.95}
        JAX_ALLOCATOR: ${XLA_PYTHON_CLIENT_ALLOCATOR:-platform}
        JAX_PREALLOC_LIMIT: ${JAX_PREALLOCATION_SIZE_LIMIT_BYTES:-8589934592}

    # (Removed explicit container_name to avoid "already in use" conflicts.)

    # Enhanced restart policy to handle port conflicts
    restart: unless-stopped

    depends_on:
      mlflow:
        condition: service_healthy

    gpus: all

    # Each variable carries a default matching the Dockerfile ENV values and
    # .env.template. Without one, a variable missing from the active env file
    # (e.g. the throw-away file generated by `invoke up`) interpolates to an
    # empty string and overrides the image's ENV with "".
    environment:
      - PYTHON_VER=${PYTHON_VER:-3.10}
      - NVIDIA_VISIBLE_DEVICES=all
      - NVIDIA_DRIVER_CAPABILITIES=compute,utility,graphics,display
      - JAX_PLATFORM_NAME=${JAX_PLATFORM_NAME:-gpu}
      - XLA_PYTHON_CLIENT_PREALLOCATE=${XLA_PYTHON_CLIENT_PREALLOCATE:-true}
      - XLA_PYTHON_CLIENT_ALLOCATOR=${XLA_PYTHON_CLIENT_ALLOCATOR:-platform}
      - XLA_PYTHON_CLIENT_MEM_FRACTION=${XLA_PYTHON_CLIENT_MEM_FRACTION:-0.95}
      - XLA_FLAGS=${XLA_FLAGS:---xla_force_host_platform_device_count=1}
      - JAX_DISABLE_JIT=${JAX_DISABLE_JIT:-false}
      - JAX_ENABLE_X64=${JAX_ENABLE_X64:-false}
      - TF_FORCE_GPU_ALLOW_GROWTH=${TF_FORCE_GPU_ALLOW_GROWTH:-false}
      - JAX_PREALLOCATION_SIZE_LIMIT_BYTES=${JAX_PREALLOCATION_SIZE_LIMIT_BYTES:-8589934592}

    volumes:
      - .:/workspace
      - ./mlruns:/workspace/mlruns        # new

    ports:
      # Enhanced port configuration with fallback options
      - "${HOST_JUPYTER_PORT:-8890}:8888"
      - "${HOST_TENSORBOARD_PORT:-}:6008"
      - "${HOST_EXPLAINER_PORT:-8050}:8050"
      - "${HOST_STREAMLIT_PORT:-}:8501"

    command: >
      jupyter lab
        --ip=0.0.0.0
        --port=8888
        --allow-root
        --NotebookApp.token="${JUPYTER_TOKEN:-jupyter}"
        --NotebookApp.allow_origin='*'

    healthcheck:
      test: ["CMD-SHELL", "bash --version && uv --help || exit 1"]
      interval: 30s
      timeout: 5s
      retries: 3

    # Enhanced labels for better debugging
    labels:
      - "com.docker.compose.project=${ENV_NAME:-docker_dev_template}"
      - "com.docker.compose.service=datascience"
      - "description=AI/ML Development Environment with GPU Support"

  mlflow:
    image: ghcr.io/mlflow/mlflow:latest
    # The backend store uses an absolute SQLite path (four slashes) inside the
    # mounted ./mlflow_db volume. A relative "sqlite:///mlflow.db" would be
    # created in the container's working directory and lost whenever the
    # container is recreated.
    command: >
      mlflow server
      --host 0.0.0.0
      --port 5000
      --backend-store-uri sqlite:////mlflow_db/mlflow.db
      --default-artifact-root /mlflow_artifacts
    environment:
      MLFLOW_EXPERIMENTS_DEFAULT_ARTIFACT_LOCATION: /mlflow_artifacts
    volumes:
      - ./mlruns:/mlflow_artifacts    # artifacts + run metadata
      - ./mlflow_db:/mlflow_db        # SQLite backend store
    ports:
      - "${HOST_MLFLOW_PORT:-5000}:5000"
    restart: unless-stopped
    healthcheck:
      test: ["CMD-SHELL",
             "python - <<'PY'\nimport requests,sys; requests.get('http://localhost:5000/health').raise_for_status()\nPY"]
      interval: 10s
      timeout: 3s
      retries: 5
      start_period: 30s






Overwriting docker-compose.yml


In [2]:
%%writefile pyproject.toml
[project]
name = "docker_dev_template"
version = "0.1.0"
description = "Pytorch and Jax GPU docker container"
authors = [
  { name = "Geoffrey Hadfield" },
]
license = "MIT"
readme = "README.md"

# ─── Restrict to Python 3.10–3.12 ──────────────────────────────
requires-python = ">=3.10,<3.13"

dependencies = [
  "pandas>=1.2.0",
  "numpy>=1.20.0",
  "matplotlib>=3.4.0",
  "mlflow>=2.10.2",
  "mlflow-skinny>=2.10.2",
  "scikit-learn>=1.4.2",
  "pymc>=5.0.0",
  "arviz>=0.14.0",
  "statsmodels>=0.13.0",
  "jupyterlab>=3.0.0",
  "seaborn>=0.11.0",
  "tabulate>=0.9.0",
  "shap>=0.40.0",
  "xgboost>=1.5.0",
  "lightgbm>=3.3.0",
  "catboost>=1.2.8,<1.3.0",
  "scipy>=1.7.0",
  "shapash[report]>=2.3.0",
  "shapiq>=0.1.0",
  "explainerdashboard==0.5.1",
  "ipywidgets>=8.0.0",
  "nutpie>=0.7.1",   # new: nutpie backend for PyMC
  "numpyro>=0.18.0,<1.0.0",
  "jax==0.6.0",
  "jaxlib==0.6.0",
  "pytensor>=2.18.3",  # explicit version for CUDA support
  "aesara>=2.9.4",     # alternative backend option
  "tqdm>=4.67.0",
  "pyarrow>=12.0.0",
  "optuna>=3.0.0",
  "optuna-integration[mlflow]>=0.2.0",
  "omegaconf>=2.3.0,<2.4.0",
  "hydra-core>=1.3.2,<1.4.0",
  "streamlit>=1.46.1,<2.0.0",
  "cloudpickle>=2.4.0",
]

[project.optional-dependencies]
dev = [
  "pytest>=7.0.0",
  "black>=23.0.0",
  "isort>=5.0.0",
  "flake8>=5.0.0",
  "mypy>=1.0.0",
  "invoke>=2.2",
]

cuda = [
  "cupy-cuda12x>=12.0.0",  # For CUDA 12.x
]

[tool.pytensor]
# Default configuration for PyTensor
device = "cuda"          # Use CUDA by default if available
floatX = "float32"       # Use float32 by default for better CUDA performance
allow_gc = true          # Allow garbage collection
optimizer = "fast_run"   # Fast run optimization by default


Overwriting pyproject.toml


In [15]:
%%writefile tasks.py
# tasks.py  ── invoke ≥2.2
from invoke import task, Context  # type: ignore
from typing import List, Optional, Union

import os
import sys
import pathlib
import tempfile
import datetime as _dt
import atexit
import socket
import contextlib
import errno


BASE_ENV = pathlib.Path(__file__).parent


# Track temporary env files for cleanup
_saved_env_files: List[str] = []


def _parse_port(port: Union[str, int, None]) -> Optional[int]:
    """
    Parse and validate a port number.
    
    Args:
        port: Port number as string or int, or None
        
    Returns:
        Validated port number as int, or None if input was None
        
    Raises:
        ValueError: If port is invalid or out of range
    """
    if port is None:
        return None
        
    try:
        port_int = int(port)
        if not (0 < port_int < 65536):
            raise ValueError(f"Port {port_int} out of valid range (1-65535)")
        return port_int
    except (TypeError, ValueError) as e:
        raise ValueError(f"Invalid port value: {port}") from e


def _first_free_port(start: int = 5200) -> int:
    """Return the first TCP port >= *start* that is unused on localhost."""
    print(f"DEBUG: Searching for free port starting at {start}")  # Debug
    import socket
    import contextlib
    for port in range(start, 65535):
        with contextlib.closing(socket.socket()) as s:
            if s.connect_ex(("127.0.0.1", port)):
                print(f"DEBUG: Found free port {port}")  # Debug
                return port
    raise RuntimeError("No free port found")


def _free_port(start=5200) -> int:
    """Find a free port by letting the OS assign one."""
    print(f"DEBUG: Finding free port starting at {start}")  # Debug
    import socket
    import contextlib
    with contextlib.closing(
        socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    ) as s:
        s.bind(('', 0))
        port = s.getsockname()[1]
        print(f"DEBUG: Found free port {port}")  # Debug
        return port


def _port_free(host: str, port: int, timeout: float = 0.1) -> bool:
    """
    Return True iff *host:port* is NOT in use.

    Uses a non-blocking TCP connect – works on Linux, macOS, Windows,
    inside or outside WSL – and does **not** rely on lsof / netstat.
    """
    print(f"DEBUG: Checking if port {port} is free on {host}")  # Debug
    try:
        with contextlib.closing(
            socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        ) as s:
            s.settimeout(timeout)
            s.connect((host, port))
            print(f"DEBUG: Port {port} is in use")  # Debug
            return False      # connection succeeded ⇒ something listening
    except (OSError, socket.timeout):
        print(f"DEBUG: Port {port} is free")  # Debug
        return True           # connection failed ⇒ port is free


def _find_port(preferred: int, start: int = 5200) -> int:
    """
    Pick a host port, favouring *preferred*.

    Args:
        preferred: Port to use when nothing is listening on it
        start: First port to probe when *preferred* is taken

    Returns:
        *preferred* if it is free on 127.0.0.1, otherwise the first free
        port at or above *start*
    """
    print(f"DEBUG: Trying preferred port {preferred}")  # Debug
    preferred_available = _port_free("127.0.0.1", preferred)
    return preferred if preferred_available else _first_free_port(start)


def _write_envfile(name: str,
                   ports: Optional[dict[str, int]] = None) -> pathlib.Path:
    """
    Write a throw-away ``.env.*`` file for the current `invoke up` run.

    The file holds ``ENV_NAME`` plus one ``HOST_*_PORT`` line for every known
    service present in *ports*; services without an entry are left out. A
    trailing comment records the generation time. The file is created next to
    this module, registered for removal at interpreter exit, and its path is
    returned.
    """
    service_vars = (
        ("jupyter", "HOST_JUPYTER_PORT"),
        ("tensorboard", "HOST_TENSORBOARD_PORT"),
        ("explainer", "HOST_EXPLAINER_PORT"),
        ("streamlit", "HOST_STREAMLIT_PORT"),
        ("mlflow", "HOST_MLFLOW_PORT"),      # NEW
    )
    chosen = ports or {}

    lines = [f"ENV_NAME={name}"]
    lines.extend(
        f"{var}={chosen[svc]}" for svc, var in service_vars if svc in chosen
    )
    lines.append(f"# generated {_dt.datetime.now().isoformat()}")

    # delete=False: the file must outlive this function so docker compose can
    # read it; cleanup happens via the atexit hook.
    with tempfile.NamedTemporaryFile(
        "w", delete=False, prefix=".env.", dir=BASE_ENV
    ) as handle:
        handle.write("\n".join(lines))

    _saved_env_files.append(handle.name)
    return pathlib.Path(handle.name)


# Exit hook: delete every env file this process generated
def _cleanup_env_files() -> None:
    """Best-effort removal of the temporary env files recorded in `_saved_env_files`."""
    for stale in _saved_env_files:
        # A file that is already gone (or cannot be removed) is not an error here.
        with contextlib.suppress(OSError):
            os.remove(stale)


atexit.register(_cleanup_env_files)


def _compose(
    c: Context,
    cmd: str,
    name: str,
    rebuild: bool = False,
    force_pty: bool = False,
    ports: Optional[dict[str, int]] = None,
) -> None:
    """
    Run ``docker compose -p <name> <cmd>`` after sanity-checking host ports.

    Args:
        c: Invoke context used to run the command
        cmd: Compose sub-command and flags (e.g. ``"up -d"``)
        name: Compose project name; also exported as ENV_NAME and
            COMPOSE_PROJECT_NAME
        rebuild: Append ``--build`` to the command
        force_pty: Request a PTY even when auto-detection would not
        ports: Optional ``service -> host port`` map. Each port is checked
            for availability and exported as the matching ``HOST_*_PORT``
            variable.

    Exits the process with status 1 when a requested host port is in use.
    """
    # ---------- pre-flight check -------------------------------------------
    if ports:
        for svc, port in ports.items():
            if port is None:
                continue
            if not _port_free("127.0.0.1", int(port)):
                print(f"❌  Host port {port} already bound – "
                      f"{svc} cannot start. Choose another port (invoke up "
                      f"--{svc}-port XXXXX) or free it first.")
                sys.exit(1)

    env = {**os.environ, "ENV_NAME": name, "COMPOSE_PROJECT_NAME": name}

    # Add port overrides if provided. The mapping mirrors _write_envfile,
    # including "mlflow": a HOST_MLFLOW_PORT already exported in the caller's
    # shell must not silently win over the port chosen for this run.
    if ports:
        port_mapping = {
            "jupyter": "HOST_JUPYTER_PORT",
            "tensorboard": "HOST_TENSORBOARD_PORT",
            "explainer": "HOST_EXPLAINER_PORT",
            "streamlit": "HOST_STREAMLIT_PORT",
            "mlflow": "HOST_MLFLOW_PORT",
        }
        for service, port in ports.items():
            # Skip unset ports so the literal string "None" is never exported.
            if port is not None and service in port_mapping:
                env[port_mapping[service]] = str(port)

    use_pty = force_pty or (os.name != "nt" and sys.stdin.isatty())

    # Warn only once per process; the flag lives on the function object.
    if not use_pty and not getattr(_compose, "_warned", False):
        print("ℹ️  PTY not supported – running without TTY.")
        _compose._warned = True  # type: ignore[attr-defined]

    if rebuild:
        full_cmd = f"docker compose -p {name} {cmd} --build"
    else:
        full_cmd = f"docker compose -p {name} {cmd}"
    c.run(full_cmd, env=env, pty=use_pty)


@task(
    help={
        "name": "Project/venv name (defaults to folder name)",
        "rebuild": "Rebuild the image (adds --build to docker compose)",
        "detach": "Run containers in the background (default: true)",
        "use_pty": "Force PTY even on non-POSIX hosts",
        "jupyter_port": "Jupyter Lab port (default: 8890)",
        "tensorboard_port": "TensorBoard port (default: auto-assigned)",
        "explainer_port": "Explainer Dashboard port (default: auto-assigned)",
        "streamlit_port": "Streamlit port (default: auto-assigned)",
        "mlflow_port": "MLflow UI port (default: 5000, auto-assigns if busy)",
    }
)
def up(
    c,
    name: Optional[str] = None,
    rebuild: bool = False,
    detach: bool = True,
    use_pty: bool = False,
    jupyter_port: Union[str, int, None] = None,
    tensorboard_port: Union[str, int, None] = None,
    explainer_port: Union[str, int, None] = None,
    streamlit_port: Union[str, int, None] = None,
    mlflow_port: Union[str, int, None] = None,
) -> None:
    """
    Build (optionally --rebuild) & start the containers with custom ports.

    Ports given on the command line are validated and must be free. The
    explainer and MLflow ports are auto-assigned when omitted (preferring
    8050 and 5000, then searching upward from 5200) and never collide with a
    port already chosen for another service in this run. The selection is
    written to a temporary env file that is handed to docker compose.

    Exits with status 1 on an invalid, busy or duplicated port.
    """
    name = name or BASE_ENV.name

    # ---------- Parse and validate all ports -----------------
    try:
        jupyter_port = _parse_port(jupyter_port)
        tensorboard_port = _parse_port(tensorboard_port)
        explainer_port = _parse_port(explainer_port)
        streamlit_port = _parse_port(streamlit_port)
        mlflow_port = _parse_port(mlflow_port)
    except ValueError as e:
        print(f"❌ Port validation failed: {e}")
        sys.exit(1)

    # ---------- build dynamic port map -----------------
    ports = {}
    if jupyter_port is not None:
        ports["jupyter"] = jupyter_port
    if tensorboard_port is not None:
        ports["tensorboard"] = tensorboard_port
    if explainer_port is not None:
        ports["explainer"] = explainer_port
    if streamlit_port is not None:
        ports["streamlit"] = streamlit_port

    def _claim(service: str, label: str,
               requested: Optional[int], preferred: int) -> int:
        """Return the host port for *service*: the requested one, or a free fallback."""
        others = {p for svc, p in ports.items() if svc != service}
        if requested is not None:
            if requested in others:
                print(f"❌ {label} port {requested} is already assigned "
                      f"to another service in this run")
                sys.exit(1)
            if not _port_free("127.0.0.1", requested):
                print(f"❌ Specified {label} port {requested} is in use")
                sys.exit(1)
            return requested
        print(f"DEBUG: No {label} port specified, finding one")  # Debug
        candidate = _find_port(preferred, 5200)
        # Nothing is bound yet, so two searches can return the same port;
        # skip past ports already handed to another service.
        while candidate in others:
            candidate = _first_free_port(candidate + 1)
        return candidate

    # NOTE: the former `from src.mlops.explainer import _first_free_port` was
    # dropped. It only bound a local name that nothing used – _find_port always
    # calls the module-level helper – while importing project code as a side
    # effect.

    # ---------- Explainer auto-assign ------------------
    print("DEBUG: Starting explainer port assignment")  # Debug
    explainer_port = _claim("explainer", "explainer", explainer_port, 8050)
    ports["explainer"] = explainer_port
    print(f"🔌 Explainer host-port → {explainer_port}")

    # ----- MLflow auto-assign (default 5000) -----------
    print("DEBUG: Starting MLflow port assignment")  # Debug
    mlflow_port = _claim("mlflow", "MLflow", mlflow_port, 5000)
    ports["mlflow"] = mlflow_port
    print(f"🔌 MLflow host-port → {mlflow_port}")

    # Generate environment file
    env_path = _write_envfile(name, ports)
    compose_cmd = "up -d" if detach else "up"

    _compose(
        c,
        f"--env-file {env_path} {compose_cmd}",
        name,
        rebuild=rebuild,
        force_pty=use_pty,
        ports=ports,
    )


@task(help={"name": "Project/venv name (defaults to folder name)"})
def stop(c, name: Optional[str] = None) -> None:
    """Run `docker compose down` for the project (named volumes are not removed)."""
    project = name if name else BASE_ENV.name
    try:
        c.run(f"docker compose -p {project} down")
        print(f"\n🛑 Stopped and removed project '{project}'")
    except Exception:
        # Any failure of the command is reported with this one message.
        print(f"❌ No running containers found for project '{project}'")

@task
def shell(c, name: str | None = None) -> None:
    """
    Open an interactive shell inside the running `datascience` container.

    Looks up the container id through `docker compose ps -q`, then runs
    `docker exec ... bash` in it. A TTY is requested only when this process
    actually has one (same rule as `_compose`); otherwise `-t` is dropped,
    because `docker exec -t` refuses to start without a terminal on stdin.

    Exits with status 1 when no container is found for the project.
    """
    name = name or BASE_ENV.name
    cmd = f"docker compose -p {name} ps -q datascience"
    listing = c.run(cmd, hide=True).stdout.split()
    if not listing:
        print(f"❌ No running 'datascience' container found for project '{name}'")
        sys.exit(1)
    cid = listing[0]  # one id per line; take the first if several are listed

    interactive = os.name != "nt" and sys.stdin.isatty()
    flags = "-it" if interactive else "-i"
    c.run(f"docker exec {flags} {cid} bash", env={"ENV_NAME": name}, pty=interactive)


@task
def clean(c) -> None:
    """Prune stopped containers + dangling images.

    Runs `docker system prune -f`. The command carries no project filter, so
    it is not limited to containers started from this repo.
    """
    # -f: do not prompt for confirmation
    c.run("docker system prune -f")


@task
def ports(c, name: str | None = None) -> None:
    """
    Show current port mappings for the named project.

    Prints a heading followed by the `docker compose ps --format table`
    output. If the command fails, prints a hint with usage examples instead.
    """
    name = name or BASE_ENV.name
    cmd = f"docker compose -p {name} ps --format table"
    try:
        # Capture the table instead of streaming it, so the heading can be
        # printed above it (it used to appear after the table).
        result = c.run(cmd, hide=True)
    except Exception:
        print(f"❌ No running containers found for project '{name}'")
        print("\n💡 Usage examples:")
        print("  invoke up --name myproject --jupyter-port 8891")
        print("  invoke up --name myproject --jupyter-port 8892 \\")
        print("    --tensorboard-port 6009")
        return

    print(f"\n📊 Port mappings for project '{name}':")
    print("=" * 50)
    print(result.stdout.rstrip())


# --- utilities ---------------------------------------------------------------
def _norm(path: str | pathlib.Path) -> str:
    """Return a lower-case, forward-slash, no-trailing-slash version of *path*."""
    p = str(path).replace("\\", "/").rstrip("/").lower()
    return p

def _docker_projects_from_this_repo() -> set[str]:
    """
    Discover every Compose *project name* whose ``working_dir`` label ends
    with the normalised, resolved path of the directory containing this file.

    Both sides are normalised with ``_norm`` (forward slashes, lower-case, no
    trailing slash) before the suffix comparison.

    Docker is invoked with an argument list rather than a shell string, so
    the ``--format`` template does not depend on the host shell's quoting
    rules.  If the docker executable cannot be started, an empty set is
    returned.
    """
    import subprocess  # local import: keeps the helper self-contained

    here_tail = _norm(pathlib.Path(__file__).parent.resolve())
    if not here_tail:
        # An empty suffix would match *every* project; refuse to guess.
        return set()

    argv = [
        "docker", "container", "ls", "-a",
        "--format",
        '{{.Label "com.docker.compose.project"}} '
        '{{.Label "com.docker.compose.project.working_dir"}}',
        "--filter", "label=com.docker.compose.project",
    ]
    try:
        listing = subprocess.run(
            argv, capture_output=True, text=True, check=False
        ).stdout
    except OSError:
        # docker is not installed / not on PATH.
        return set()

    projects: set[str] = set()
    for line in listing.strip().splitlines():
        # "<project> <working_dir>"; the directory itself may contain spaces.
        fields = line.split(maxsplit=1)
        if len(fields) != 2:
            continue
        proj, wd = fields
        if _norm(wd).endswith(here_tail):
            projects.add(proj)
    return projects

# --- task --------------------------------------------------------------------
@task(
    help={
        "name": "Project name (defaults to folder). Ignored with --all.",
        "all":  "Remove *all* projects launched from this repo.",
        "rmi":  "Image-removal policy: all | local | none (default: local).",
    }
)
def down(c, name: str | None = None, all: bool = False, rmi: str = "local"):
    """
    Tear a project down completely (containers, volumes, orphans and,
    depending on ``--rmi``, images) so the next `invoke up` starts clean.

    Examples
    --------
    invoke down                  # nuke current-folder project
    invoke down --name ml_project --rmi all   # wipe everything for ml_project
    invoke down --all            # tear down every project from this repo

    Raises
    ------
    ValueError
        If ``rmi`` is not one of ``all``, ``local`` or ``none``.
    """
    if rmi not in {"all", "local", "none"}:
        raise ValueError("--rmi must be all | local | none")

    # Either every project discovered for this repo, or just the named one.
    if all:
        targets = _docker_projects_from_this_repo()
    else:
        targets = {name or BASE_ENV.name}

    flag_parts = ["-v", "--remove-orphans"]
    if rmi != "none":
        flag_parts.extend(["--rmi", rmi])
    flags = " ".join(flag_parts)

    for proj in targets:
        try:
            c.run(f"docker compose -p {proj} down {flags}")
            print(f"🗑️  Removed project '{proj}'")
        except Exception:
            # Best effort: a failing project must not stop the remaining ones.
            print(f"⚠️  Nothing to remove for '{proj}'")


@task(
    help={
        "yaml": "Path to dashboard.yaml file",
        "port": "Port to serve on (default: 8150)",
        "host": "Host to bind to (default: 0.0.0.0)",
    }
)
def dashboard(c, yaml: str, port: int = 8150, host: str = "0.0.0.0") -> None:
    """
    Serve a saved ExplainerDashboard from a YAML configuration file.

    Intended for re-serving dashboards that were previously saved with
    build_and_log_dashboard(save_yaml=True).

    The task exits with status 1 when the YAML file does not exist, when the
    requested port is reported busy by ``_port_free``, or when loading or
    serving the dashboard raises.

    Examples:
        invoke dashboard --yaml dashboard.yaml
        invoke dashboard --yaml dashboard.yaml --port 8200
    """
    import sys
    from pathlib import Path
    from src.mlops.explainer import load_dashboard_yaml

    config_file = Path(yaml)
    if not config_file.exists():
        print(f"❌ Dashboard YAML file not found: {config_file}")
        sys.exit(1)

    # Refuse to start when something already listens on the target port.
    port_available = _port_free(host, port)
    if not port_available:
        print(f"❌ Port {port} is already in use on {host}")
        sys.exit(1)

    try:
        print(f"🔄 Loading dashboard from {config_file}")
        board = load_dashboard_yaml(config_file)

        print(f"🌐 Serving ExplainerDashboard on {host}:{port}")
        board.run(port=port, host=host, use_waitress=True, open_browser=False)
    except Exception as err:
        print(f"❌ Failed to load or serve dashboard: {err}")
        sys.exit(1)



Overwriting tasks.py


In [16]:
%%writefile tests/diagnose_devcontainer.py
#!/usr/bin/env python3
"""
Comprehensive diagnostic script for dev container issues.
Run this inside the container to diagnose Python environment and remote extension problems.
"""

import sys
import os
import subprocess
import json
from pathlib import Path


def run_command(cmd, description):
    """Execute *cmd* through the shell, print a short report, and return the result.

    Args:
        cmd: Shell command line to execute.
        description: Heading printed above the command's report.

    Returns:
        The ``subprocess.CompletedProcess`` for the command, or ``None`` when
        running it raised an exception.
    """
    print(f"\n🔍 {description}")
    print("=" * 60)
    try:
        completed = subprocess.run(cmd, shell=True, capture_output=True, text=True)
        succeeded = completed.returncode == 0
        report = (
            f"✅ {completed.stdout.strip()}"
            if succeeded
            else f"❌ Error (code {completed.returncode}): {completed.stderr.strip()}"
        )
        print(report)
        return completed
    except Exception as exc:
        print(f"❌ Exception: {exc}")
        return None


def check_paths_and_environment():
    """Print the interpreter, search-path and virtual-environment diagnostics.

    Reports ``sys.executable``, ``sys.version``, the first three ``sys.path``
    entries, the ``VIRTUAL_ENV`` and ``PATH`` environment variables, and
    whether ``/app/.venv`` exists with a ``site-packages`` directory.

    The ``site-packages`` check globs ``lib/python*/site-packages`` so it does
    not depend on one specific Python minor version being installed.
    """
    print("\n🐍 PYTHON ENVIRONMENT DIAGNOSTICS")
    print("=" * 60)

    # Python executable and version
    print(f"Python executable: {sys.executable}")
    print(f"Python version: {sys.version}")
    print(f"Python path: {sys.path[:3]}...")  # First few paths

    # Environment variables
    print(f"\nVIRTUAL_ENV: {os.environ.get('VIRTUAL_ENV', 'Not set')}")
    path_entries = os.environ.get('PATH', '').split(os.pathsep)
    print(f"PATH (first 3): {os.pathsep.join(path_entries[:3])}")

    # Virtual environment validation
    venv_path = Path('/app/.venv')
    if venv_path.exists():
        has_site_packages = any(venv_path.glob('lib/python*/site-packages'))
        print(f"✅ Virtual environment exists at {venv_path}")
        print(f"   - bin directory: {list(venv_path.glob('bin/python*'))}")
        print(f"   - site-packages: {has_site_packages}")
    else:
        print(f"❌ Virtual environment NOT found at {venv_path}")


def check_key_packages(packages=None):
    """Try to import each package and print its version or the failure.

    Args:
        packages: Optional iterable of importable module names.  When omitted,
            the default set of key packages for this container is checked.

    Every module is imported with ``importlib.import_module`` so dotted names
    resolve to the submodule itself; a missing ``__version__`` attribute is
    reported as ``unknown``.
    """
    import importlib  # local import: only this check needs it

    print("\n📦 PACKAGE IMPORT TESTS")
    print("=" * 60)

    if packages is None:
        packages = [
            'jax', 'torch', 'numpy', 'pandas', 'matplotlib',
            'jupyterlab', 'streamlit', 'sklearn'
        ]

    for package in packages:
        try:
            module = importlib.import_module(package)
            version = getattr(module, '__version__', 'unknown')
            print(f"✅ {package}: {version}")
        except ImportError as e:
            print(f"❌ {package}: Import failed - {e}")
        except Exception as e:
            # Import succeeded far enough to run code that then blew up.
            print(f"⚠️  {package}: {e}")


def check_gpu_environment():
    """Print the value of each GPU-related environment variable.

    Variables that are absent from the environment are shown as ``Not set``.
    """
    print("\n🎮 GPU ENVIRONMENT VARIABLES")
    print("=" * 60)

    watched_variables = (
        'XLA_PYTHON_CLIENT_PREALLOCATE',
        'XLA_PYTHON_CLIENT_ALLOCATOR',
        'XLA_PYTHON_CLIENT_MEM_FRACTION',
        'JAX_PLATFORM_NAME',
        'XLA_FLAGS',
        'JAX_DISABLE_JIT',
        'JAX_ENABLE_X64',
        'JAX_PREALLOCATION_SIZE_LIMIT_BYTES',
        'TF_FORCE_GPU_ALLOW_GROWTH',
        'NVIDIA_VISIBLE_DEVICES',
        'NVIDIA_DRIVER_CAPABILITIES',
    )

    for variable in watched_variables:
        current = os.getenv(variable, 'Not set')
        print(f"   {variable}: {current}")


def check_gpu_support():
    """Report whether JAX and PyTorch can see and use a GPU.

    Each framework is probed inside its own ``try`` block so a failure in one
    does not prevent the other from being checked.  When a GPU is detected a
    small computation is run as a smoke test.
    """
    print("\n🎮 ENHANCED GPU SUPPORT CHECK")
    print("=" * 60)

    # ---- JAX -----------------------------------------------------------
    try:
        import jax
        print(f"JAX version: {jax.__version__}")

        jax_devices = jax.devices()
        print(f"JAX devices: {jax_devices}")
        for idx, dev in enumerate(jax_devices):
            print(f"   Device {idx}: {dev}")

        # A device counts as a GPU when its string form mentions gpu or cuda.
        accelerator_found = any(
            marker in str(dev).lower()
            for dev in jax_devices
            for marker in ('gpu', 'cuda')
        )
        if not accelerator_found:
            print("⚠️  JAX GPU support not detected")
            print("   This might be due to GPU architecture compatibility")
        else:
            print("✅ JAX GPU/CUDA support detected!")

            # Smoke test: sum a matrix of ones
            try:
                import jax.numpy as jnp
                total = jnp.sum(jnp.ones((1000, 1000)))
                print(f"   ✅ JAX GPU computation test passed: sum = {total}")
            except Exception as err:
                print(f"   ⚠️  JAX GPU computation test failed: {err}")

    except Exception as err:
        print(f"❌ JAX GPU check failed: {err}")

    # ---- PyTorch -------------------------------------------------------
    try:
        import torch
        print(f"\nPyTorch version: {torch.__version__}")
        print(f"PyTorch CUDA available: {torch.cuda.is_available()}")

        if not torch.cuda.is_available():
            print("⚠️  PyTorch CUDA not available")
            print("   Check CUDA installation and GPU compatibility")
        else:
            gpu_total = torch.cuda.device_count()
            print(f"✅ PyTorch CUDA device count: {gpu_total}")

            for idx in range(gpu_total):
                try:
                    gpu_name = torch.cuda.get_device_name(idx)
                    total_bytes = torch.cuda.get_device_properties(idx).total_memory
                    print(f"   Device {idx}: {gpu_name}")
                    print(f"     Total memory: {total_bytes / (1024**3):.1f} GB")
                except Exception as err:
                    print(f"   Device {idx}: Error getting info - {err}")

            # Smoke test: sum a matrix of ones allocated on the first GPU
            try:
                first_gpu = torch.device('cuda:0')
                ones = torch.ones(1000, 1000, device=first_gpu)
                total = torch.sum(ones)
                print(f"   ✅ PyTorch GPU computation test passed: sum = {total}")
            except Exception as err:
                print(f"   ⚠️  PyTorch GPU computation test failed: {err}")

    except Exception as err:
        print(f"❌ PyTorch GPU check failed: {err}")


def check_workspace_mount():
    """Report whether ``/workspace`` exists and contains the expected project files.

    Lists up to ten directory entries and checks for ``.devcontainer``,
    ``pyproject.toml`` and ``docker-compose.yml``.
    """
    print("\n📁 WORKSPACE MOUNT CHECK")
    print("=" * 60)

    workspace_path = Path('/workspace')
    if not workspace_path.exists():
        print("❌ /workspace directory does not exist")
        return

    print("✅ /workspace directory exists")
    try:
        entry_names = [entry.name for entry in list(workspace_path.iterdir())[:10]]
        print(f"   Contents (first 10): {entry_names}")

        # Files whose presence indicates the repo root was mounted
        for expected in ('.devcontainer', 'pyproject.toml', 'docker-compose.yml'):
            present = (workspace_path / expected).exists()
            status = "✅ Found" if present else "❌ Missing"
            print(f"   {status}: {expected}")
    except Exception as err:
        print(f"   ❌ Error reading workspace: {err}")


def check_dev_container_config():
    """Print key fields from ``/workspace/.devcontainer/devcontainer.json``.

    Shows the container name, the configured VS Code Python interpreter path
    and the workspace folder; any error while reading or parsing the file is
    reported instead of raised.
    """
    print("\n⚙️  DEV CONTAINER CONFIG CHECK")
    print("=" * 60)

    config_path = Path('/workspace/.devcontainer/devcontainer.json')
    if not config_path.exists():
        print("❌ devcontainer.json not found")
        return

    print("✅ devcontainer.json found")
    try:
        with open(config_path) as handle:
            config = json.load(handle)
        print(f"   Name: {config.get('name', 'Not specified')}")

        # customizations -> vscode -> settings, each level optional
        vscode_settings = (
            config.get('customizations', {})
            .get('vscode', {})
            .get('settings', {})
        )
        interpreter = vscode_settings.get('python.defaultInterpreterPath', 'Not specified')
        print(f"   Python path: {interpreter}")
        print(f"   Workspace folder: {config.get('workspaceFolder', 'Not specified')}")
    except Exception as err:
        print(f"   ❌ Error reading config: {err}")


def main():
    """Run every diagnostic in sequence and finish with a troubleshooting guide."""
    print("🔍 DEV CONTAINER COMPREHENSIVE DIAGNOSTICS")
    print("=" * 80)
    print(f"Running from: {os.getcwd()}")
    print(f"User: {os.getenv('USER', 'unknown')}")
    print(f"Container hostname: {os.getenv('HOSTNAME', 'unknown')}")

    # Shell-level probes: (command, heading)
    shell_probes = (
        ("uv --version", "UV Version"),
        ("which python", "Python Location"),
        ("ls -la /app/.venv/", "Virtual Environment Contents"),
        ("mount | grep workspace", "Workspace Mount Status"),
        ("nvidia-smi", "NVIDIA GPU Status"),
    )
    for command, heading in shell_probes:
        run_command(command, heading)

    # Python-level checks, executed in this order
    python_checks = (
        check_paths_and_environment,
        check_gpu_environment,
        check_key_packages,
        check_gpu_support,
        check_workspace_mount,
        check_dev_container_config,
    )
    for check in python_checks:
        check()

    print("\n" + "=" * 80)
    print("🎯 SUMMARY & RECOMMENDATIONS")
    print("=" * 80)
    print("If you see issues:")
    recommendations = (
        "1. ❌ Virtual env missing → Check Dockerfile uv sync step",
        "2. ❌ Workspace not mounted → Check devcontainer.json mounts config",
        "3. ❌ Packages missing → Check uv.lock and pip install steps",
        "4. ⚠️  GPU not detected → Check docker-compose.yml gpu settings",
        "5. 🔧 For VS Code issues → Check python.defaultInterpreterPath setting",
        "6. 🎮 For GPU issues → Check NVIDIA drivers and CUDA compatibility",
    )
    for tip in recommendations:
        print(tip)
    print("\n✅ All checks passed = ready for development!")


# Run the diagnostics only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

Overwriting tests/diagnose_devcontainer.py


In [17]:
%%writefile tests/test_pytorch_jax_gpu.py
#!/usr/bin/env python3
"""
Test script to verify that PyTorch and JAX can access the GPU,
and that PyJAGS is working correctly.
"""

import sys


def test_pytorch_gpu():
    """Check that PyTorch sees a CUDA device and can run a matmul on it.

    Returns:
        ``True`` when CUDA is available and the timed matrix multiplication
        completes; ``False`` when PyTorch is missing, CUDA is unavailable, or
        any step raises.
    """
    print("\n=== Testing PyTorch GPU ===")
    try:
        import torch
        print(f"PyTorch version: {torch.__version__}")
        print(f"CUDA available: {torch.cuda.is_available()}")

        cuda_missing = not torch.cuda.is_available()
        if cuda_missing:
            print("❌ PyTorch CUDA not available!")
            return False

        print(f"CUDA device count: {torch.cuda.device_count()}")
        print(f"Current device: {torch.cuda.current_device()}")
        print(f"Device name: {torch.cuda.get_device_name(0)}")

        # Two random operands, moved to the GPU
        lhs = torch.rand(1000, 1000).cuda()
        rhs = torch.rand(1000, 1000).cuda()

        # CUDA events bracket the kernel so the elapsed time is measured on-device
        tic = torch.cuda.Event(enable_timing=True)
        toc = torch.cuda.Event(enable_timing=True)

        tic.record()
        product = torch.matmul(lhs, rhs)
        toc.record()

        # The events are only valid once the GPU has finished its work
        torch.cuda.synchronize()
        print(f"Matrix multiplication time: {tic.elapsed_time(toc):.2f} ms")
        print(f"Result shape: {product.shape}")
        print("✅ PyTorch GPU test passed!")
        return True

    except ImportError:
        print("❌ PyTorch not found!")
        return False
    except Exception as err:
        print(f"❌ Error during PyTorch GPU test: {err}")
        return False


def test_jax_gpu():
    """Check that JAX sees a GPU device and can run a jitted matmul on it.

    Returns:
        ``True`` when the first JAX device is a GPU and the test computation
        completes (or fails only with the known "ptxas too old" message,
        which is treated as a partial success); ``False`` when JAX is
        missing, no GPU device is found, or any other error occurs.
    """
    print("\n=== Testing JAX GPU ===")
    try:
        import jax
        import jax.numpy as jnp

        print(f"JAX version: {jax.__version__}")

        # Force GPU platform
        jax.config.update('jax_platform_name', 'gpu')

        # Get device count and details
        devices = jax.devices()
        device_count = len(devices)
        print(f"Available devices: {device_count}")

        for i, device in enumerate(devices):
            print(f"Device {i}: {device}")

        # Accept both "gpu" and "cuda" in the device string, matching the
        # detection used by the dev-container diagnostic script.
        first_device = str(devices[0]).lower() if devices else ''
        if not ('cuda' in first_device or 'gpu' in first_device):
            print("❌ No GPU devices found by JAX!")
            return False

        # The configuration dump is informational only, so a missing
        # attribute must not turn into a failed GPU test.
        config_values = getattr(jax.config, 'values', None)
        if config_values is not None:
            print(f"JAX configuration: {config_values}")

        # Run a simple GPU computation
        print("Running a test computation on GPU...")
        try:
            x = jnp.ones((1000, 1000))
            y = jnp.ones((1000, 1000))

            # Use JIT compilation for better performance
            @jax.jit
            def matmul(a, b):
                return jnp.matmul(a, b)

            result = matmul(x, y)
            # JAX dispatches asynchronously: wait for the device so that a
            # runtime failure surfaces here instead of going unnoticed.
            result.block_until_ready()
            print(f"Result shape: {result.shape}")

            print("✅ JAX GPU test passed!")
            return True
        except RuntimeError as e:
            if "ptxas too old" in str(e):
                print(f"⚠️ JAX GPU detected but CUDA compatibility issue: {e}")
                print("⚠️ JAX can see the GPU but there's a CUDA version compatibility issue.")
                print("⚠️ This is considered a partial success since the GPU is detected.")
                return True
            else:
                raise

    except ImportError:
        print("❌ JAX not found!")
        return False
    except Exception as e:
        print(f"❌ Error during JAX GPU test: {e}")
        return False


def test_pyjags():
    """Check that PyJAGS imports, compiles a tiny model and can draw samples.

    Returns:
        ``True`` when the model initialises and sampling completes;
        ``False`` when PyJAGS is missing or any step raises.
    """
    print("\n=== Testing PyJAGS ===")
    try:
        import pyjags
        print(f"PyJAGS version: {pyjags.__version__}")

        # Minimal model: one normal observation with priors on mu and sigma
        model_code = """
        model {
            # Likelihood
            y ~ dnorm(mu, 1/sigma^2)
            
            # Priors
            mu ~ dnorm(0, 0.001)
            sigma ~ dunif(0, 100)
        }
        """

        # Single observed value
        observed = {'y': 0.5}

        # Build the model with one chain and a short adaptation phase
        jags_model = pyjags.Model(model_code, data=observed, chains=1, adapt=100)
        print("JAGS model initialized successfully!")

        # Draw samples for both parameters
        draws = jags_model.sample(200, vars=['mu', 'sigma'])
        print("JAGS sampling completed successfully!")

        # Summarise the draws
        mu_draws = draws['mu']
        sigma_draws = draws['sigma']
        print(f"mu mean: {mu_draws.mean():.4f}")
        print(f"sigma mean: {sigma_draws.mean():.4f}")

        print("✅ PyJAGS test passed!")
        return True

    except ImportError:
        print("❌ PyJAGS not found!")
        return False
    except Exception as err:
        print(f"❌ Error during PyJAGS test: {err}")
        return False


if __name__ == "__main__":
    print("Running GPU and PyJAGS verification tests...")

    # Every check runs (in this order) before any summary line is printed.
    outcomes = {
        "PyTorch GPU": test_pytorch_gpu(),
        "JAX GPU": test_jax_gpu(),
        "PyJAGS": test_pyjags(),
    }

    print("\n=== Test Summary ===")
    for label, passed in outcomes.items():
        verdict = '✅ PASS' if passed else '❌ FAIL'
        print(f"{label}: {verdict}")

    # Exit status mirrors the overall result so callers can act on it.
    if all(outcomes.values()):
        print("\n🎉 All tests passed! The container is working correctly.")
        sys.exit(0)
    else:
        print("\n❌ Some tests failed. Please check the output for details.")
        sys.exit(1)




Overwriting tests/test_pytorch_jax_gpu.py
