# Troubleshooting & Environment


**Goal:** Quickly diagnose common environment issues (JAX CPU/GPU backends, the NVIDIA driver, and optional hyper-optimizer packages).


In [None]:

import os, sys, importlib.util, subprocess

def _sh(cmd):
    try:
        out = subprocess.check_output(cmd, shell=True, stderr=subprocess.STDOUT, text=True, timeout=8)
        print(out)
    except Exception as e:
        print(f"[cmd failed] {cmd}\n{e}")

# Environment report: interpreter, JAX backend, GPU driver, optional packages.
print("Python:", sys.version)
# sys.prefix only differs from sys.base_prefix inside a virtual environment,
# so say explicitly when this is a base interpreter rather than a venv.
if sys.prefix != sys.base_prefix:
    print("venv:", sys.prefix)
else:
    print("venv:", sys.prefix, "(not a virtual environment)")

# Show the env vars that steer JAX device selection and GPU memory use;
# they are the usual cause of "wrong device" / out-of-memory surprises.
for var in (
    "JAX_PLATFORMS",
    "CUDA_VISIBLE_DEVICES",
    "XLA_PYTHON_CLIENT_PREALLOCATE",
    "XLA_PYTHON_CLIENT_MEM_FRACTION",
):
    print(f"{var}:", os.environ.get(var, "<unset>"))

try:
    import jax
except Exception as e:  # ImportError, or a broken jaxlib failing at import time
    print("JAX not importable:", e)
else:
    print("JAX version:", jax.__version__)
    try:
        # The import can succeed while backend initialisation fails (e.g.
        # JAX_PLATFORMS=cuda without a usable GPU); report that case
        # separately instead of mislabelling it as an import failure.
        print("JAX default backend:", jax.default_backend())
        print("JAX devices:", jax.devices())
    except Exception as e:
        print("JAX imported, but device query failed:", e)

print("nvidia-smi (optional):")
_sh("nvidia-smi | head -n 10")

# Optional packages: only check that they are importable; find_spec on a
# top-level name locates the module without importing it.
for mod in ["kahypar", "optuna", "nevergrad", "cmaes"]:
    print(f"{mod:10s}:", importlib.util.find_spec(mod) is not None)

print("\nSet before Python starts to force device:")
print("  export JAX_PLATFORMS=cpu   # or: cuda")
print("Optional to manage VRAM:")
print("  export XLA_PYTHON_CLIENT_PREALLOCATE=false")
print("  export XLA_PYTHON_CLIENT_MEM_FRACTION=0.85")
