Skip to content

Commit

Permalink
Temporarily disable FLAX_LAZY_RNG to unbreak tests.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 429406755
  • Loading branch information
adarob authored and t5-copybara committed Feb 17, 2022
1 parent ba9f403 commit a7db934
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 2 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/build.yaml
Expand Up @@ -15,7 +15,10 @@ jobs:
run: |
pip install -e .[test] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
- name: Test with pytest
run: pytest
# TODO(adarob): Re-enable once tests are updated.
run: |
export FLAX_LAZY_RNG=no
pytest
# The below step just reports the success or failure of tests as a "commit status".
# This is needed for copybara integration.
- name: Report success or failure as github status
Expand Down
4 changes: 3 additions & 1 deletion t5x/eval.py
Expand Up @@ -24,11 +24,13 @@
import os
from typing import Optional, Sequence, Type

from absl import logging
# Set Linen to add profiling information when constructing Modules.
# Must be set before flax imports.
# pylint:disable=g-import-not-at-top
os.environ['FLAX_PROFILE'] = 'true'
# TODO(adarob): Re-enable once users are notified and tests are updated.
os.environ['FLAX_LAZY_RNG'] = 'no'
from absl import logging
from clu import metric_writers
import jax
import seqio
Expand Down
6 changes: 6 additions & 0 deletions t5x/infer.py
Expand Up @@ -30,6 +30,12 @@
import time
from typing import Any, Callable, Iterator, List, Mapping, Optional, Sequence, Tuple, Type

# Set Linen to add profiling information when constructing Modules.
# Must be set before flax imports.
# pylint:disable=g-import-not-at-top
os.environ['FLAX_PROFILE'] = 'true'
# TODO(adarob): Re-enable once users are notified and tests are updated.
os.environ['FLAX_LAZY_RNG'] = 'no'
from absl import logging
import jax
import jax.numpy as jnp
Expand Down
6 changes: 6 additions & 0 deletions t5x/train.py
Expand Up @@ -22,6 +22,12 @@
import time
from typing import Callable, Iterator, Sequence, Mapping, Tuple, Type, Optional

# Set Linen to add profiling information when constructing Modules.
# Must be set before flax imports.
# pylint:disable=g-import-not-at-top
os.environ['FLAX_PROFILE'] = 'true'
# TODO(adarob): Re-enable once users are notified and tests are updated.
os.environ['FLAX_LAZY_RNG'] = 'no'
from absl import logging
from clu import metric_writers
import jax
Expand Down

0 comments on commit a7db934

Please sign in to comment.