Skip to content

Commit

Permalink
feat(train): use compilation cache
Browse files Browse the repository at this point in the history
  • Loading branch information
borisdayma committed Feb 7, 2022
1 parent 68cc185 commit da9367c
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions tools/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from flax.training import train_state
from flax.training.common_utils import onehot
from jax.experimental import PartitionSpec, maps
from jax.experimental.compilation_cache import compilation_cache as cc
from jax.experimental.pjit import pjit, with_sharding_constraint
from tqdm import tqdm
from transformers import HfArgumentParser
Expand All @@ -53,6 +54,11 @@
set_partitions,
)

cc.initialize_cache(
"/home/boris/dalle-mini/jax_cache", max_cache_size_bytes=5 * 2**30
)


logger = logging.getLogger(__name__)


Expand Down

0 comments on commit da9367c

Please sign in to comment.