Skip to content

Commit

Permalink
Jax persistent compilation cache user guide.
Browse files Browse the repository at this point in the history
This user guide covers using the cache on local filesystems
and Google Cloud.

PiperOrigin-RevId: 623236335
  • Loading branch information
jax authors committed Apr 9, 2024
1 parent 828e60c commit afb775c
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 1 deletion.
74 changes: 74 additions & 0 deletions docs/persistent_compilation_cache.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Persistent Compilation Cache

JAX has an optional disk cache for compiled programs. If enabled, JAX will
store copies of compiled programs on disk, which can save recompilation time
when running the same or similar tasks repeatedly.

## Usage

The compilation cache is enabled when the
[cache-location](https://github.com/google/jax/blob/jax-v0.4.26/jax/_src/config.py#L1206)
is set. This should be done prior to the first compilation. Set the location as
follows:

```
import jax
# Make sure this is called before jax runs any operations!
jax.config.update("jax_compilation_cache_dir", "cache-location")
```

See the sections below for more detail on `cache-location`.

[`set_cache_dir()`](https://github.com/google/jax/blob/jax-v0.4.26/jax/experimental/compilation_cache/compilation_cache.py#L18)
is an alternate way of setting `cache-location`.

### Local filesystem

`cache-location` can be a directory on the local filesystem. For example:

```
import jax
jax.config.update("jax_compilation_cache_dir", "/tmp/jax-cache")
```

Note: the cache does not have an eviction mechanism implemented. If the
cache-location is a directory in the local filesystem, its size will continue
to grow unless files are manually deleted.

### Google Cloud

When running on Google Cloud, the compilation cache can be placed on a Google
Cloud Storage (GCS) bucket. We recommend the following configuration:

* Create the bucket in the same region as where the workload will run.

* Create the bucket in the same project as the workload’s VM(s). Ensure that
permissions are set so that the VM(s) can write to the bucket.

* There is no need for replication for smaller workloads. Larger workloads
could benefit from replication.

* Use “Standard” for the default storage class for the bucket.

* Set the soft delete policy to its shortest: 7 days.

* Set the object lifecycle to the expected duration of the workload run.
For example, if the workload is expected to run for 10 days, set the object
lifecycle to 10 days. That should cover restarts that occur during the entire
run. Use `age` for the lifecycle condition and `Delete` for the action. See
[Object Lifecycle Management](https://cloud.google.com/storage/docs/lifecycle)
for details. If the object lifecycle is not set, the cache will continue to
grow since there is no eviction mechanism implemented.

* All encryption policies are supported.

Assuming that `gs://jax-cache` is the GCS bucket, set `cache-location` as
follows:

```
import jax
jax.config.update("jax_compilation_cache_dir", "gs://jax-cache")
```
2 changes: 1 addition & 1 deletion docs/user_guides.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ or deployed codebases.
device_memory_profiling
debugging/index
gpu_performance_tips

persistent_compilation_cache

.. toctree::
:maxdepth: 1
Expand Down

0 comments on commit afb775c

Please sign in to comment.