# Colab GPU Environment Setup

Installs a side-by-side Python 3.10 interpreter within Colab, clones the repository and installs the pinned dependencies. Installs compatible versions of JAX (CUDA 12 build) and NumPyro, together with the NVIDIA CUDA wheels.
Creates a `py310cuda` launcher that runs Python 3.10 with `LD_LIBRARY_PATH` pointing at the libraries from those wheels.

In [None]:
import os, subprocess, sys, stat, textwrap, pathlib

# Absolute path of the repository checkout; the clone step changes into this
# directory once the clone has finished.
REPO = "/content/sgfa_qmap-pd"

In [None]:
# Clone the repository into /content, replacing any earlier checkout
CLONE_URL = "https://github.com/meeramads/sgfa_qmap-pd.git"

os.chdir("/content")
# rm's exit status is not checked (subprocess.run without check=True),
# so a missing directory on the first run is not an error.
subprocess.run(["rm", "-rf", "sgfa_qmap-pd"])
# A failed clone raises CalledProcessError and stops the cell.
subprocess.check_call(["git", "clone", CLONE_URL])
# Work from inside the fresh checkout from here on.
os.chdir(REPO)

In [None]:
# Install Python 3.10 side-by-side
# Each step must succeed before the next one runs; check_call raises
# CalledProcessError on the first non-zero exit status.
py310_setup_steps = (
    # Download the installer script into the current directory.
    ["wget", "-q", "https://github.com/korakot/kora/releases/download/v0.10/py310.sh"],
    # Run the installer (presumably -p /usr/local is the install prefix - confirm
    # against the installer's own help).
    ["bash", "./py310.sh", "-b", "-f", "-p", "/usr/local"],
    # Confirm the new interpreter is on PATH by printing its version.
    ["python3.10", "-V"],
)
for step in py310_setup_steps:
    subprocess.check_call(step)

In [None]:
# Install dependencies
# Always go through "python3.10 -m pip" so packages land in the Python 3.10
# installation rather than in the notebook's own interpreter.
py310_pip_install = ["python3.10", "-m", "pip", "install"]

# Upgrade pip itself first, then install the requirements file found in the
# current working directory.
subprocess.check_call(py310_pip_install + ["-U", "pip"])
subprocess.check_call(py310_pip_install + ["-r", "requirements.txt"])

In [None]:
# Set up JAX (CUDA) + NumPyro to ensure compatibility
py310_pip = ["python3.10", "-m", "pip"]
jax_wheel_index = "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"

# Remove any jax/jaxlib already installed for Python 3.10. The exit status is
# not checked (subprocess.run without check=True), so this step cannot abort
# the cell.
subprocess.run(py310_pip + ["uninstall", "-y", "jax", "jaxlib"])

# Install the pinned CUDA 12 build of JAX, using the extra wheel index given
# via -f.
subprocess.check_call(
    py310_pip + ["install", "-U", "jax[cuda12_pip]==0.4.20", "-f", jax_wheel_index]
)

# Install the pinned NumPyro release alongside it.
subprocess.check_call(py310_pip + ["install", "numpyro==0.13.2"])

In [None]:
# NVIDIA CUDA libraries into the *py310* site-packages
# Requirement specifiers for the CUDA 12 ("-cu12") wheels, one per library.
nvidia_wheel_specs = [
    "nvidia-cudnn-cu12>=8.9,<9",
    "nvidia-cublas-cu12>=12.2",
    "nvidia-cuda-runtime-cu12>=12.2",
    "nvidia-cusolver-cu12>=11.4",
    "nvidia-cusparse-cu12>=12.1",
    "nvidia-cufft-cu12>=11.0",
    "nvidia-cuda-cupti-cu12>=12.2",
    "nvidia-nvjitlink-cu12>=12.2",
    "nvidia-nccl-cu12>=2.18",
]

# One quiet pip invocation so all specifiers are resolved together.
subprocess.check_call(
    ["python3.10", "-m", "pip", "install", "-q", *nvidia_wheel_specs]
)

In [None]:
# Build LD_LIBRARY_PATH for those wheels and write a launcher
#
# Ask the python3.10 interpreter (not this notebook's interpreter) for its
# first site-packages directory; the "nvidia/<component>/lib" directories are
# looked up underneath it.
py310_site = subprocess.check_output(
    ["python3.10","-c","import site; print(site.getsitepackages()[0])"],
    text=True
).strip()
subdirs = ["cudnn/lib","cublas/lib","cufft/lib","cusolver/lib","cusparse/lib",
           "cuda_runtime/lib","cuda_cupti/lib","nvjitlink/lib","nccl/lib"]
lib_paths = [os.path.join(py310_site,"nvidia",d) for d in subdirs]
# Keep only directories that actually exist so the path has no dead entries.
lib_paths = [p for p in lib_paths if os.path.isdir(p)]
LD = ":".join(lib_paths)

# An empty entry in LD_LIBRARY_PATH (leading/trailing ":" or "::") is treated
# by the dynamic linker as the current working directory.  To avoid creating
# one:
#   * skip the export entirely when no wheel library directory was found, and
#   * append the inherited value via ${LD_LIBRARY_PATH:+:...}, which expands
#     to ":<old value>" only when the variable is set and non-empty.
if LD:
    ld_export = (
        f'export LD_LIBRARY_PATH="{LD}${{LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}}"\n'
    )
else:
    ld_export = ""

wrapper = "/usr/local/bin/py310cuda"
# The launcher sets the environment and then replaces itself (exec) with
# python3.10, forwarding all of its arguments unchanged.
pathlib.Path(wrapper).write_text(
    "#!/bin/bash\n"
    + ld_export
    + "export XLA_PYTHON_CLIENT_PREALLOCATE=false\n"
    + "export XLA_PYTHON_CLIENT_MEM_FRACTION=0.70\n"
    + "export JAX_PLATFORM_NAME=gpu\n"
    + 'exec python3.10 "$@"\n'
)
# Add the owner-execute bit on top of the existing mode so the script can be
# run directly by name.
os.chmod(wrapper, os.stat(wrapper).st_mode | stat.S_IEXEC)

In [None]:
# Verify JAX sees the GPU
# One-line program executed by the launcher: prints the active backend
# platform and the list of devices JAX reports.
jax_gpu_probe = (
    "import jax; print('backend:', jax.lib.xla_bridge.get_backend().platform, "
    "'| devices:', jax.devices())"
)
# check_call raises CalledProcessError if the probe exits non-zero
# (for example when the import fails).
subprocess.check_call(["py310cuda", "-c", jax_gpu_probe])

In [None]:
# Double check: run the same GPU probe again, this time through the notebook's
# `!` shell escape instead of subprocess.
!py310cuda -c "import jax; print('backend:', jax.lib.xla_bridge.get_backend().platform, '| devices:', jax.devices())"

---
# Training the model

To train on the GPU, call ```!py310cuda run_analysis.py``` with the flag ```--device gpu```.

Run ```!py310cuda run_analysis.py --help || py310cuda run_analysis.py -h``` for detailed information on other available flags.